initial import for Open Source 🎉

This commit is contained in:
Jimmy Zelinskie 2019-11-12 11:09:47 -05:00
parent 1898c361f3
commit 9c0dd3b722
2048 changed files with 218743 additions and 0 deletions

0
test/__init__.py Normal file
View file

6
test/analytics.py Normal file
View file

@ -0,0 +1,6 @@
class FakeMixpanel(object):
def track(*args, **kwargs):
pass
def init_app(app):
return FakeMixpanel()

View file

@ -0,0 +1,2 @@
[Service]
Environment=DOCKER_OPTS='--insecure-registry="0.0.0.0/0"'

View file

@ -0,0 +1,5 @@
FROM quay.io/quay/busybox
RUN date > somefile
RUN date +%s%N > anotherfile
RUN date +"%T.%N" > thirdfile
RUN echo "testing 123" > testfile

0
test/clients/__init__.py Normal file
View file

131
test/clients/client.py Normal file
View file

@ -0,0 +1,131 @@
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from six import add_metaclass
Command = namedtuple('Command', ['command'])
# NOTE: FileCopy is done via `scp`, instead of `ssh` which is how Command is run.
FileCopy = namedtuple('FileCopy', ['source', 'destination'])
@add_metaclass(ABCMeta)
class Client(object):
""" Client defines the interface for all clients being tested. """
@abstractmethod
def setup_client(self, registry_host, verify_tls):
""" Returns the commands necessary to setup the client inside the VM.
"""
@abstractmethod
def populate_test_image(self, registry_host, namespace, name):
""" Returns the commands necessary to populate the test image. """
@abstractmethod
def print_version(self):
""" Returns the commands necessary to print the version of the client. """
@abstractmethod
def login(self, registry_host, username, password):
""" Returns the commands necessary to login. """
@abstractmethod
def push(self, registry_host, namespace, name):
""" Returns the commands necessary to test pushing. """
@abstractmethod
def pre_pull_cleanup(self, registry_host, namespace, name):
""" Returns the commands necessary to cleanup before pulling. """
@abstractmethod
def pull(self, registry_host, namespace, name):
""" Returns the commands necessary to test pulling. """
@abstractmethod
def verify(self, registry_host, namespace, name):
""" Returns the commands necessary to verify the pulled image. """
class DockerClient(Client):
def __init__(self, requires_v1=False, requires_email=False):
self.requires_v1 = requires_v1
self.requires_email = requires_email
def setup_client(self, registry_host, verify_tls):
if not verify_tls:
cp_command = ('sudo cp /home/core/50-insecure-registry.conf ' +
'/etc/systemd/system/docker.service.d/50-insecure-registry.conf')
yield Command('sudo mkdir -p /etc/systemd/system/docker.service.d/')
yield FileCopy('50-insecure-registry.conf', '/home/core')
yield Command(cp_command)
yield Command('sudo systemctl daemon-reload')
yield FileCopy('Dockerfile.test', '/home/core/Dockerfile')
def populate_test_image(self, registry_host, namespace, name):
if self.requires_v1:
# These versions of Docker don't support the new TLS cert on quay.io, so we need to pull
# from v1.quay.io and then retag so the build works.
yield Command('docker pull v1.quay.io/quay/busybox')
yield Command('docker tag v1.quay.io/quay/busybox quay.io/quay/busybox')
yield Command('docker build -t %s/%s/%s .' % (registry_host, namespace, name))
def print_version(self):
yield Command('docker version')
def login(self, registry_host, username, password):
email_param = ""
if self.requires_email:
# cli will block forever if email is not set for version under 1.10.3
email_param = "--email=none "
yield Command('docker login --username=%s --password=%s %s %s' %
(username, password, email_param, registry_host))
def push(self, registry_host, namespace, name):
yield Command('docker push %s/%s/%s' % (registry_host, namespace, name))
def pre_pull_cleanup(self, registry_host, namespace, name):
prefix = 'v1.' if self.requires_v1 else ''
yield Command('docker rmi -f %s/%s/%s' % (registry_host, namespace, name))
yield Command('docker rmi -f %squay.io/quay/busybox' % prefix)
def pull(self, registry_host, namespace, name):
yield Command('docker pull %s/%s/%s' % (registry_host, namespace, name))
def verify(self, registry_host, namespace, name):
yield Command('docker run %s/%s/%s echo testfile' % (registry_host, namespace, name))
class PodmanClient(Client):
def __init__(self):
self.verify_tls = False
def setup_client(self, registry_host, verify_tls):
yield FileCopy('Dockerfile.test', '/home/vagrant/Dockerfile')
self.verify_tls = verify_tls
def populate_test_image(self, registry_host, namespace, name):
yield Command('sudo podman build -t %s/%s/%s /home/vagrant/' % (registry_host, namespace, name))
def print_version(self):
yield Command('sudo podman version')
def login(self, registry_host, username, password):
yield Command('sudo podman login --tls-verify=%s --username=%s --password=%s %s' %
(self.verify_tls, username, password, registry_host))
def push(self, registry_host, namespace, name):
yield Command('sudo podman push --tls-verify=%s %s/%s/%s' % (self.verify_tls, registry_host, namespace, name))
def pre_pull_cleanup(self, registry_host, namespace, name):
yield Command('sudo podman rmi -f %s/%s/%s' % (registry_host, namespace, name))
yield Command('sudo podman rmi -f quay.io/quay/busybox')
def pull(self, registry_host, namespace, name):
yield Command('sudo podman pull --tls-verify=%s %s/%s/%s' % (self.verify_tls, registry_host, namespace, name))
def verify(self, registry_host, namespace, name):
yield Command('sudo podman run %s/%s/%s echo testfile' % (registry_host, namespace, name))

View file

@ -0,0 +1,278 @@
import os
import subprocess
import sys
import time
import unicodedata
from threading import Thread
from termcolor import colored
from client import DockerClient, PodmanClient, Command, FileCopy
def remove_control_characters(s):
return "".join(ch for ch in unicode(s, 'utf-8') if unicodedata.category(ch)[0] != "C")
# These tuples are the box&version and the client to use.
BOXES = [
("kleesc/centos7-podman --box-version=0.11.1.1", PodmanClient()), # podman 0.11.1.1
("kleesc/coreos --box-version=1911.4.0", DockerClient()), # docker 18.06.1
("kleesc/coreos --box-version=1800.7.0", DockerClient()), # docker 18.03.1
("kleesc/coreos --box-version=1688.5.3", DockerClient()), # docker 17.12.1
("kleesc/coreos --box-version=1632.3.0", DockerClient()), # docker 17.09.1
("kleesc/coreos --box-version=1576.5.0", DockerClient()), # docker 17.09.0
("kleesc/coreos --box-version=1520.9.0", DockerClient()), # docker 1.12.6
("kleesc/coreos --box-version=1235.6.0", DockerClient()), # docker 1.12.3
("kleesc/coreos --box-version=1185.5.0", DockerClient()), # docker 1.11.2
("kleesc/coreos --box-version=1122.3.0", DockerClient(requires_email=True)), # docker 1.10.3
("kleesc/coreos --box-version=899.17.0", DockerClient(requires_email=True)), # docker 1.9.1
("kleesc/coreos --box-version=835.13.0", DockerClient(requires_email=True)), # docker 1.8.3
("kleesc/coreos --box-version=766.5.0", DockerClient(requires_email=True)), # docker 1.7.1
("kleesc/coreos --box-version=717.3.0", DockerClient(requires_email=True)), # docker 1.6.2
("kleesc/coreos --box-version=647.2.0", DockerClient(requires_email=True)), # docker 1.5.0
("kleesc/coreos --box-version=557.2.0", DockerClient(requires_email=True)), # docker 1.4.1
("kleesc/coreos --box-version=522.6.0", DockerClient(requires_email=True)), # docker 1.3.3
("yungsang/coreos --box-version=1.3.7", DockerClient(requires_email=True)), # docker 1.3.2
("yungsang/coreos --box-version=1.2.9", DockerClient(requires_email=True)), # docker 1.2.0
("yungsang/coreos --box-version=1.1.5", DockerClient(requires_email=True)), # docker 1.1.2
("yungsang/coreos --box-version=1.0.0", DockerClient(requires_email=True)), # docker 1.0.1
("yungsang/coreos --box-version=0.9.10", DockerClient(requires_email=True)), # docker 1.0.0
("yungsang/coreos --box-version=0.9.6", DockerClient(requires_email=True)), # docker 0.11.1
("yungsang/coreos --box-version=0.9.1", DockerClient(requires_v1=True, requires_email=True)), # docker 0.10.0
("yungsang/coreos --box-version=0.3.1", DockerClient(requires_v1=True, requires_email=True)), # docker 0.9.0
]
def _check_vagrant():
vagrant_command = 'vagrant'
vagrant = any(os.access(os.path.join(path, vagrant_command), os.X_OK)
for path in os.environ.get('PATH').split(':'))
vagrant_plugins = subprocess.check_output([vagrant_command, 'plugin', 'list'])
return (vagrant, 'vagrant-scp' in vagrant_plugins)
def _load_ca(box, ca_cert):
if 'coreos' in box:
yield FileCopy(ca_cert, '/home/core/ca.pem')
yield Command('sudo cp /home/core/ca.pem /etc/ssl/certs/ca.pem')
yield Command('sudo update-ca-certificates')
yield Command('sudo systemctl daemon-reload')
elif 'centos' in box:
yield FileCopy(ca_cert, '/home/vagrant/ca.pem')
yield Command('sudo cp /home/vagrant/ca.pem /etc/pki/ca-trust/source/anchors/')
yield Command('sudo update-ca-trust enable')
yield Command('sudo update-ca-trust extract')
else:
raise Exception("unknown box for loading CA cert")
# extra steps to initialize the system
def _init_system(box):
if 'coreos' in box:
# disable the update-engine so that it's easier to debug
yield Command('sudo systemctl stop update-engine')
class CommandFailedException(Exception):
pass
class SpinOutputter(Thread):
def __init__(self, initial_message):
super(SpinOutputter, self).__init__()
self.previous_line = ''
self.next_line = initial_message
self.running = True
self.daemon = True
@staticmethod
def spinning_cursor():
while 1:
for cursor in '|/-\\':
yield cursor
def set_next(self, text):
first_line = text.split('\n')[0].strip()
first_line = remove_control_characters(first_line)
self.next_line = first_line[:80]
def _clear_line(self):
sys.stdout.write('\r')
sys.stdout.write(' ' * (len(self.previous_line) + 2))
sys.stdout.flush()
sys.stdout.write('\r')
sys.stdout.flush()
self.previous_line = ''
def stop(self):
self._clear_line()
self.running = False
def run(self):
spinner = SpinOutputter.spinning_cursor()
while self.running:
self._clear_line()
sys.stdout.write('\r')
sys.stdout.flush()
sys.stdout.write(next(spinner))
sys.stdout.write(" ")
sys.stdout.write(colored(self.next_line, attrs=['dark']))
sys.stdout.flush()
self.previous_line = self.next_line
time.sleep(0.25)
def _run_and_wait(command, error_allowed=False):
# Run the command itself.
outputter = SpinOutputter('Running command %s' % command)
outputter.start()
output = ''
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
for line in iter(process.stdout.readline, ''):
output += line
outputter.set_next(line)
result = process.wait()
outputter.stop()
failed = result != 0 and not error_allowed
# vagrant scp doesn't report auth failure as non-0 exit
failed = failed or (len(command) > 1 and command[1] == 'scp' and
'authentication failures' in output)
if failed:
print colored('>>> Command `%s` Failed:' % command, 'red')
print output
raise CommandFailedException()
return output
def _indent(text, amount):
return ''.join((' ' * amount) + line for line in text.splitlines(True))
def scp_to_vagrant(source, destination):
'''scp_to_vagrant copies the file from source to destination in the default
vagrant box without vagrant scp, which may fail on some coreos boxes.
'''
config = _run_and_wait(['vagrant', 'ssh-config'])
config_lines = config.split('\n')
params = ['scp']
for i in xrange(len(config_lines)):
if 'Host default' in config_lines[i]:
config_i = i + 1
while config_i < len(config_lines):
if config_lines[config_i].startswith(' '):
params += ['-o', '='.join(config_lines[config_i].split())]
else:
break
config_i += 1
break
params.append(source)
params.append('core@localhost:' + destination)
return _run_and_wait(params)
def _run_commands(commands):
last_result = None
for command in commands:
if isinstance(command, Command):
last_result = _run_and_wait(['vagrant', 'ssh', '-c', command.command])
else:
try:
last_result = _run_and_wait(['vagrant', 'scp', command.source, command.destination])
except CommandFailedException as e:
print colored('>>> Retry FileCopy command without vagrant scp...', 'red')
# sometimes the vagrant scp fails because of invalid ssh configuration.
last_result = scp_to_vagrant(command.source, command.destination)
return last_result
def _run_box(box, client, registry, ca_cert):
vagrant, vagrant_scp = _check_vagrant()
if not vagrant:
print("vagrant command not found")
return
if not vagrant_scp:
print("vagrant-scp plugin not installed")
return
namespace = 'devtable'
repo_name = 'testrepo%s' % int(time.time())
username = 'devtable'
password = 'password'
print colored('>>> Box: %s' % box, attrs=['bold'])
print colored('>>> Starting box', 'yellow')
_run_and_wait(['vagrant', 'destroy', '-f'], error_allowed=True)
_run_and_wait(['rm', 'Vagrantfile'], error_allowed=True)
_run_and_wait(['vagrant', 'init'] + box.split(' '))
_run_and_wait(['vagrant', 'up', '--provider', 'virtualbox'])
_run_commands(_init_system(box))
if ca_cert:
print colored('>>> Setting up runtime with cert ' + ca_cert, 'yellow')
_run_commands(_load_ca(box, ca_cert))
_run_commands(client.setup_client(registry, verify_tls=True))
else:
print colored('>>> Setting up runtime with insecure HTTP(S)', 'yellow')
_run_commands(client.setup_client(registry, verify_tls=False))
print colored('>>> Client version', 'cyan')
runtime_version = _run_commands(client.print_version())
print _indent(runtime_version, 4)
print colored('>>> Populating test image', 'yellow')
_run_commands(client.populate_test_image(registry, namespace, repo_name))
print colored('>>> Testing login', 'cyan')
_run_commands(client.login(registry, username, password))
print colored('>>> Testing push', 'cyan')
_run_commands(client.push(registry, namespace, repo_name))
print colored('>>> Removing all images', 'yellow')
_run_commands(client.pre_pull_cleanup(registry, namespace, repo_name))
print colored('>>> Testing pull', 'cyan')
_run_commands(client.pull(registry, namespace, repo_name))
print colored('>>> Verifying', 'cyan')
_run_commands(client.verify(registry, namespace, repo_name))
print colored('>>> Tearing down box', 'magenta')
_run_and_wait(['vagrant', 'destroy', '-f'], error_allowed=True)
print colored('>>> Successfully tested box %s' % box, 'green')
print ""
def test_clients(registry='10.0.2.2:5000', ca_cert=''):
print colored('>>> Running against registry ', attrs=['bold']) + colored(registry, 'cyan')
for box, client in BOXES:
try:
_run_box(box, client, registry, ca_cert)
except CommandFailedException:
sys.exit(-1)
if __name__ == "__main__":
test_clients(sys.argv[1] if len(sys.argv) > 1 else '10.0.2.2:5000', sys.argv[2]
if len(sys.argv) > 2 else '')

38
test/conftest.py Normal file
View file

@ -0,0 +1,38 @@
from __future__ import print_function
import pytest
def pytest_collection_modifyitems(config, items):
"""
This adds a pytest marker that consistently shards all collected tests.
Use it like the following:
$ py.test -m shard_1_of_3
$ py.test -m shard_2_of_3
$ py.test -m shard_3_of_3
This code was originally adopted from the MIT-licensed ansible/molecule@9e7b79b:
Copyright (c) 2015-2018 Cisco Systems, Inc.
Copyright (c) 2018 Red Hat, Inc.
"""
mark_opt = config.getoption('-m')
if not mark_opt.startswith('shard_'):
return
desired_shard, _, total_shards = mark_opt[len('shard_'):].partition('_of_')
if not total_shards or not desired_shard:
return
desired_shard = int(desired_shard)
total_shards = int(total_shards)
if not 0 < desired_shard <= total_shards:
raise ValueError('desired_shard must be greater than 0 and not bigger than total_shards')
for test_counter, item in enumerate(items):
shard = test_counter%total_shards + 1
marker = getattr(pytest.mark, 'shard_{}_of_{}'.format(shard, total_shards))
item.add_marker(marker)
print('Running sharded test group #{} out of {}'.format(desired_shard, total_shards))

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,19 @@
-----BEGIN PGP PUBLIC KEY BLOCK-----
Version: GnuPG v2
mQENBFTVMzABCAC8jcnCrNHKk0LgyZTdTFtf9Qm2bK27Y0EyyI8tWefUt4LhQRCA
14dksJVzqWBtpHJnqkYUwfoXZmdz4e9fSS1mmoiHlDwzkuNXx2J1HAnXSxgNMV1D
JQmfxhKQzFTgkTEN03txPZrOMrDNIZSw0gkAbiBGuQXk9/HNGbzdjkd3vk1GF7Vk
v1vITmWQG+QQi7H8zR1NYYuFQb5cdDDuOoQWHXNMIZmK27StZ6MUot3NlquZbs1q
5Gr1HHog0qx+0uYn441zghZ9R1JqaAig0V3eJ8UAbTIMZPO09UUBQKC7O7OgOX/H
92zGWGwkTMUqJNJUr/dj5ocQbpFk8X3yz+d9ABEBAAG0RFF1YXkuaW8gQUNJIENv
bnZlcnRlciAoQUNJIGNvbnZlcnNpb24gc2lnbmluZyBrZXkpIDxzdXBwb3J0QHF1
YXkuaW8+iQE5BBMBAgAjBQJU1TMwAhsDBwsJCAcDAgEGFQgCCQoLBBYCAwECHgEC
F4AACgkQRjIEfu6zIiHo9Af+MCE4bUOrQ6yrHSPHebHwSARULaTB0Rlj4BAXlv+A
nUJDaaYaYExo8SHZMWF5X4d4mh57DJOsIXMjIWNKpf9/0hpxRu+P8p77YtXOOeRS
3xFdq7cOK1yQ8h/iRoXyLaxAFgWvVH+Ttmx4DLr+NsyzEQBjADeBCcF4YR9OZ7fj
ZYsoq68hH0W7zgZTSrCgvyGxdpu+UWWk/eV/foktxKBMV8K2GmAwyOlsAm59PgUI
EhfFH0WAEx6+jsMFLkn7USPWomFeyMwJJEiMKYCwALWIbNz1/1dtrZZs2QmBcjAu
AMFQhx8fykH4ON8a6fpS3TOEzf0HV1NX295O8wb8vS9B7w==
=aImY
-----END PGP PUBLIC KEY BLOCK-----

BIN
test/data/test.db Normal file

Binary file not shown.

1
test/data/test.kid Normal file
View file

@ -0,0 +1 @@
test_service_key

27
test/data/test.pem Normal file
View file

@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAyqdQgnelhAPMSeyH0kr3UGePK9oFOmNfwD0Ymnh7YYXr21VH
WwyM2eVW3cnLd9KXywDFtGSe9oFDbnOuMCdUowdkBcaHju+isbv5KEbNSoy/T2Ri
p+6L0cY63YzcMJzv1nEYztYXS8wz76pSK81BKBCLapqOCmcPeCvV9yaoFZYvZEsX
Cl5jjXN3iujSzSF5Z6PpNFlJWTErMT2Z4QfbDKX2Nw6vJN6JnGpTNHZvgvcyNX8v
kSgVpQ8DFnFkBEx54PvRV5KpHAq6AsJxKONMo11idQS2PfCNpa2hvz9O6UZe+eIX
8jPo5NW8TuGZJumbdPT/nxTDLfCqfiZboeI0PwIDAQABAoIBAHVJhLUd3jObpx6Z
wLobHSvx49Dza9cxMHeoZJbyaCY3Rhw5LQUrLFHoA/B1HEeLIMMi/Um8eqwcgBRq
60N/X+LDIkadclNtqfHH4xpGcAZXk1m1tcuPqmiMnAEhx0ZzbfPknQEIs47w7pYl
M02ai71OZgIa1V57614XsMxMGTf0HadsmqC0cxLQ21UxROzCvv49N26ZTav7aLwl
1yW+scP/lo2HH6VJFTNJduOBgmnnMVIEhYLHa26lsf3biARf8TyV2xupex7s46iD
RegXZzzAlHx8/qkFoTfENNefAbBX87E+r1gs9zmWEo+2DaKYxmuDTqAbk107c1Jo
XQ59MRECgYEA0utdbhdgL0ll69tV4wGdZb49VTCzfmGrab8X+8Hy6aFzUDYqIsJD
1l5e7fPxY+MqeP4XoJ7q5RAqIBAXSubds743RElMxvLsiy6Q/HuZ5ICxjUZ2Amop
mItcGG9RXaiqXAKHk6ZMgFhO3/NAVv2it+XnP3uLucAgmqh7Wp7ML2MCgYEA9fet
kYirz32ZAvxWDrTapQjSDkyCAaCZGB+BQ5paBLeMzTIcVHiu2aggCJrc4YoqB91D
JHlynZhvxOK0m1KXHDnbPn9YqwsVZDTIU4PnpC0KEj357VujXDH/tD0ggzrm5ruQ
4o0SpfavI7MAe0vUlv46x+CfIzSq+kPRenrRBHUCgYEAyCAIk1fcrKFg8ow3jt/O
X2ZFPZqrBMRZZ0mo0PiyqljFWBs8maRnx3PdcLvgk11MxGabNozy5YsT3T5HS4uI
Wm6mc8V08uQ16s2xRc9lMnmlfh2YBSyD8ThxlsGwm0RY+FpyF3dX6QNhO37L0n5w
MTsT0pk/92xDw1sPR+maZW8CgYBp8GJ2k1oExUDZE1vxe53MhS8L75HzJ3uo8zDW
sC1jaLchThr7mvscThh1/FV0YvDVcExR8mkWTaieMVK+r2TcSGMQ2QKUsPJmtYEu
z1o+0RNMZhs2S0jiFbrfo5BUVVNMP68YlNBaYRRwGNH1SOTon9kra6i/HhkiL4GS
8kECXQKBgAs/DqfCobJsIMi7TDcG1FkQEKwPmnKmh8rDX3665gCRz7rOoIG8u05A
J6pQqrUrPRI+AAtVM4nW0z4KE07ruTJ/8wapTErm/5Bp5bikiaHy7NY2kHj3hVwr
KYh700ZUPV9vd+xUpfTNoVyvV2tu4QnG8ihKII6vfCPItEpE8glo
-----END RSA PRIVATE KEY-----

301
test/fixtures.py Normal file
View file

@ -0,0 +1,301 @@
import os
from cachetools.func import lru_cache
from collections import namedtuple
from datetime import datetime, timedelta
import pytest
import shutil
import inspect
from flask import Flask, jsonify
from flask_login import LoginManager
from flask_principal import identity_loaded, Permission, Identity, identity_changed, Principal
from flask_mail import Mail
from peewee import SqliteDatabase, InternalError
from mock import patch
from app import app as application
from auth.permissions import on_identity_loaded
from data import model
from data.database import close_db_filter, db, configure
from data.model.user import LoginWrappedDBUser, create_robot, lookup_robot, create_user_noverify
from data.model.repository import create_repository
from data.model.repo_mirror import enable_mirroring_for_repository
from data.userfiles import Userfiles
from endpoints.api import api_bp
from endpoints.appr import appr_bp
from endpoints.web import web
from endpoints.v1 import v1_bp
from endpoints.v2 import v2_bp
from endpoints.verbs import verbs as verbs_bp
from endpoints.webhooks import webhooks
from initdb import initialize_database, populate_database
from path_converters import APIRepositoryPathConverter, RegexConverter, RepositoryPathConverter
from test.testconfig import FakeTransaction
INIT_DB_PATH = 0
@pytest.fixture(scope="session")
def init_db_path(tmpdir_factory):
""" Creates a new database and appropriate configuration. Note that the initial database
is created *once* per session. In the non-full-db-test case, the database_uri fixture
makes a copy of the SQLite database file on disk and passes a new copy to each test.
"""
# NOTE: We use a global here because pytest runs this code multiple times, due to the fixture
# being imported instead of being in a conftest. Moving to conftest has its own issues, and this
# call is quite slow, so we simply cache it here.
global INIT_DB_PATH
INIT_DB_PATH = INIT_DB_PATH or _init_db_path(tmpdir_factory)
return INIT_DB_PATH
def _init_db_path(tmpdir_factory):
if os.environ.get('TEST_DATABASE_URI'):
return _init_db_path_real_db(os.environ.get('TEST_DATABASE_URI'))
return _init_db_path_sqlite(tmpdir_factory)
def _init_db_path_real_db(db_uri):
""" Initializes a real database for testing by populating it from scratch. Note that this does
*not* add the tables (merely data). Callers must have migrated the database before calling
the test suite.
"""
configure({
"DB_URI": db_uri,
"SECRET_KEY": "superdupersecret!!!1",
"DB_CONNECTION_ARGS": {
'threadlocals': True,
'autorollback': True,
},
"DB_TRANSACTION_FACTORY": _create_transaction,
"DATABASE_SECRET_KEY": "anothercrazykey!",
})
populate_database()
return db_uri
def _init_db_path_sqlite(tmpdir_factory):
""" Initializes a SQLite database for testing by populating it from scratch and placing it into
a temp directory file.
"""
sqlitedbfile = str(tmpdir_factory.mktemp("data").join("test.db"))
sqlitedb = 'sqlite:///{0}'.format(sqlitedbfile)
conf = {"TESTING": True,
"DEBUG": True,
"SECRET_KEY": "superdupersecret!!!1",
"DATABASE_SECRET_KEY": "anothercrazykey!",
"DB_URI": sqlitedb}
os.environ['DB_URI'] = str(sqlitedb)
db.initialize(SqliteDatabase(sqlitedbfile))
application.config.update(conf)
application.config.update({"DB_URI": sqlitedb})
initialize_database()
db.obj.execute_sql('PRAGMA foreign_keys = ON;')
db.obj.execute_sql('PRAGMA encoding="UTF-8";')
populate_database()
close_db_filter(None)
return str(sqlitedbfile)
@pytest.yield_fixture()
def database_uri(monkeypatch, init_db_path, sqlitedb_file):
""" Returns the database URI to use for testing. In the SQLite case, a new, distinct copy of
the SQLite database is created by copying the initialized database file (sqlitedb_file)
on a per-test basis. In the non-SQLite case, a reference to the existing database URI is
returned.
"""
if os.environ.get('TEST_DATABASE_URI'):
db_uri = os.environ['TEST_DATABASE_URI']
monkeypatch.setenv("DB_URI", db_uri)
yield db_uri
else:
# Copy the golden database file to a new path.
shutil.copy2(init_db_path, sqlitedb_file)
# Monkeypatch the DB_URI.
db_path = 'sqlite:///{0}'.format(sqlitedb_file)
monkeypatch.setenv("DB_URI", db_path)
yield db_path
# Delete the DB copy.
assert '..' not in sqlitedb_file
assert 'test.db' in sqlitedb_file
os.remove(sqlitedb_file)
@pytest.fixture()
def sqlitedb_file(tmpdir):
""" Returns the path at which the initialized, golden SQLite database file will be placed. """
test_db_file = tmpdir.mkdir("quaydb").join("test.db")
return str(test_db_file)
def _create_transaction(db):
return FakeTransaction()
@pytest.fixture()
def appconfig(database_uri):
""" Returns application configuration for testing that references the proper database URI. """
conf = {
"TESTING": True,
"DEBUG": True,
"DB_URI": database_uri,
"SECRET_KEY": 'superdupersecret!!!1',
"DATABASE_SECRET_KEY": "anothercrazykey!",
"DB_CONNECTION_ARGS": {
'threadlocals': True,
'autorollback': True,
},
"DB_TRANSACTION_FACTORY": _create_transaction,
"DATA_MODEL_CACHE_CONFIG": {
'engine': 'inmemory',
},
"USERFILES_PATH": "userfiles/",
"MAIL_SERVER": "",
"MAIL_DEFAULT_SENDER": "support@quay.io",
"DATABASE_SECRET_KEY": "anothercrazykey!",
}
return conf
AllowedAutoJoin = namedtuple('AllowedAutoJoin', ['frame_start_index', 'pattern_prefixes'])
ALLOWED_AUTO_JOINS = [
AllowedAutoJoin(0, ['test_']),
AllowedAutoJoin(0, ['<', 'test_']),
]
CALLER_FRAMES_OFFSET = 3
FRAME_NAME_INDEX = 3
@pytest.fixture()
def initialized_db(appconfig):
""" Configures the database for the database found in the appconfig. """
under_test_real_database = bool(os.environ.get('TEST_DATABASE_URI'))
# Configure the database.
configure(appconfig)
# Initialize caches.
model._basequery._lookup_team_roles()
model._basequery.get_public_repo_visibility()
model.log.get_log_entry_kinds()
if not under_test_real_database:
# Make absolutely sure foreign key constraints are on.
db.obj.execute_sql('PRAGMA foreign_keys = ON;')
db.obj.execute_sql('PRAGMA encoding="UTF-8";')
assert db.obj.execute_sql('PRAGMA foreign_keys;').fetchone()[0] == 1
assert db.obj.execute_sql('PRAGMA encoding;').fetchone()[0] == 'UTF-8'
# If under a test *real* database, setup a savepoint.
if under_test_real_database:
with db.transaction():
test_savepoint = db.savepoint()
test_savepoint.__enter__()
yield # Run the test.
try:
test_savepoint.rollback()
test_savepoint.__exit__(None, None, None)
except InternalError:
# If postgres fails with an exception (like IntegrityError) mid-transaction, it terminates
# it immediately, so when we go to remove the savepoint, it complains. We can safely ignore
# this case.
pass
else:
if os.environ.get('DISALLOW_AUTO_JOINS', 'false').lower() == 'true':
# Patch get_rel_instance to fail if we try to load any non-joined foreign key. This will allow
# us to catch missing joins when running tests.
def get_rel_instance(self, instance):
value = instance.__data__.get(self.name)
if value is not None or self.name in instance.__rel__:
if self.name not in instance.__rel__:
# NOTE: We only raise an exception if this auto-lookup occurs from non-testing code.
# Testing code can be a bit inefficient.
lookup_allowed = False
try:
outerframes = inspect.getouterframes(inspect.currentframe())
except IndexError:
# Happens due to a bug in Jinja.
outerframes = []
for allowed_auto_join in ALLOWED_AUTO_JOINS:
if lookup_allowed:
break
if len(outerframes) >= allowed_auto_join.frame_start_index + CALLER_FRAMES_OFFSET:
found_match = True
for index, pattern_prefix in enumerate(allowed_auto_join.pattern_prefixes):
frame_info = outerframes[index + CALLER_FRAMES_OFFSET]
if not frame_info[FRAME_NAME_INDEX].startswith(pattern_prefix):
found_match = False
break
if found_match:
lookup_allowed = True
break
if not lookup_allowed:
raise Exception('Missing join on instance `%s` for field `%s`', instance, self.name)
obj = self.rel_model.get(self.field.rel_field == value)
instance.__rel__[self.name] = obj
return instance.__rel__[self.name]
elif not self.field.null:
raise self.rel_model.DoesNotExist
return value
with patch('peewee.ForeignKeyAccessor.get_rel_instance', get_rel_instance):
yield
else:
yield
@pytest.fixture()
def app(appconfig, initialized_db):
""" Used by pytest-flask plugin to inject a custom app instance for testing. """
app = Flask(__name__)
login_manager = LoginManager(app)
@app.errorhandler(model.DataModelException)
def handle_dme(ex):
response = jsonify({'message': str(ex)})
response.status_code = 400
return response
@login_manager.user_loader
def load_user(user_uuid):
return LoginWrappedDBUser(user_uuid)
@identity_loaded.connect_via(app)
def on_identity_loaded_for_test(sender, identity):
on_identity_loaded(sender, identity)
Principal(app, use_sessions=False)
app.url_map.converters['regex'] = RegexConverter
app.url_map.converters['apirepopath'] = APIRepositoryPathConverter
app.url_map.converters['repopath'] = RepositoryPathConverter
app.register_blueprint(api_bp, url_prefix='/api')
app.register_blueprint(appr_bp, url_prefix='/cnr')
app.register_blueprint(web, url_prefix='/')
app.register_blueprint(verbs_bp, url_prefix='/c1')
app.register_blueprint(v1_bp, url_prefix='/v1')
app.register_blueprint(v2_bp, url_prefix='/v2')
app.register_blueprint(webhooks, url_prefix='/webhooks')
app.config.update(appconfig)
Userfiles(app)
Mail(app)
return app

70
test/fulldbtest.sh Executable file
View file

@ -0,0 +1,70 @@
set -e
up_mysql() {
# Run a SQL database on port 3306 inside of Docker.
docker run --name mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=password -d mysql:5.7
# Sleep for 10s to get MySQL get started.
echo 'Sleeping for 10...'
sleep 10
# Add the database to mysql.
docker run --rm --link mysql:mysql mysql:5.7 sh -c 'echo "create database genschema;" | mysql -h"$MYSQL_PORT_3306_TCP_ADDR" -P"$MYSQL_PORT_3306_TCP_PORT" -uroot -ppassword'
}
down_mysql() {
docker kill mysql || true
docker rm -v mysql || true
}
up_postgres() {
# Run a SQL database on port 5432 inside of Docker.
docker run --name postgres -p 5432:5432 -d postgres
# Sleep for 10s to get SQL get started.
echo 'Sleeping for 10...'
sleep 10
# Add the database to postgres.
docker run --rm --link postgres:postgres postgres sh -c 'echo "create database genschema" | psql -h "$POSTGRES_PORT_5432_TCP_ADDR" -p "$POSTGRES_PORT_5432_TCP_PORT" -U postgres'
docker run --rm --link postgres:postgres postgres sh -c 'echo "CREATE EXTENSION IF NOT EXISTS pg_trgm;" | psql -h "$POSTGRES_PORT_5432_TCP_ADDR" -p "$POSTGRES_PORT_5432_TCP_PORT" -U postgres -d genschema'
}
down_postgres() {
docker kill postgres || true
docker rm -v postgres || true
}
run_tests() {
# Initialize the database with schema.
PYTHONPATH=. TEST_DATABASE_URI=$1 TEST=true alembic upgrade head
# Run the full test suite.
PYTHONPATH=. SKIP_DB_SCHEMA=true TEST_DATABASE_URI=$1 TEST=true py.test ${2:-.} --ignore=endpoints/appr/test/
}
CIP=${CONTAINERIP-'127.0.0.1'}
echo "> Using container IP address $CIP"
# NOTE: MySQL is currently broken on setup.
# Test (and generate, if requested) via MySQL.
echo '> Starting MySQL'
down_mysql
up_mysql
echo '> Running Full Test Suite (mysql)'
set +e
run_tests "mysql+pymysql://root:password@$CIP/genschema" $1
set -e
down_mysql
# Test via Postgres.
echo '> Starting Postgres'
down_postgres
up_postgres
echo '> Running Full Test Suite (postgres)'
set +e
run_tests "postgresql://postgres@$CIP/genschema" $1
set -e
down_postgres

83
test/helpers.py Normal file
View file

@ -0,0 +1,83 @@
import multiprocessing
import time
import socket
from contextlib import contextmanager
from data.database import LogEntryKind, LogEntry3
class assert_action_logged(object):
""" Specialized assertion for ensuring that a log entry of a particular kind was added under the
context of this call.
"""
def __init__(self, log_kind):
self.log_kind = log_kind
self.existing_count = 0
def _get_log_count(self):
return LogEntry3.select().where(LogEntry3.kind == LogEntryKind.get(name=self.log_kind)).count()
def __enter__(self):
self.existing_count = self._get_log_count()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val is None:
updated_count = self._get_log_count()
error_msg = 'Missing new log entry of kind %s' % self.log_kind
assert self.existing_count == (updated_count - 1), error_msg
_LIVESERVER_TIMEOUT = 5
@contextmanager
def liveserver_app(flask_app, port):
"""
Based on https://github.com/jarus/flask-testing/blob/master/flask_testing/utils.py
Runs the given Flask app as a live web server locally, on the given port, starting it
when called and terminating after the yield.
Usage:
with liveserver_app(flask_app, port):
# Code that makes use of the app.
"""
shared = {}
def _can_ping_server():
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.connect(('localhost', port))
except socket.error:
success = False
else:
success = True
finally:
sock.close()
return success
def _spawn_live_server():
worker = lambda app, port: app.run(port=port, use_reloader=False)
shared['process'] = multiprocessing.Process(target=worker, args=(flask_app, port))
shared['process'].start()
start_time = time.time()
while True:
elapsed_time = (time.time() - start_time)
if elapsed_time > _LIVESERVER_TIMEOUT:
_terminate_live_server()
raise RuntimeError("Failed to start the server after %d seconds. " % _LIVESERVER_TIMEOUT)
if _can_ping_server():
break
def _terminate_live_server():
if shared.get('process'):
shared.get('process').terminate()
shared.pop('process')
try:
_spawn_live_server()
yield
finally:
_terminate_live_server()

77
test/queue_threads.py Normal file
View file

@ -0,0 +1,77 @@
import unittest
import json
import time
from functools import wraps
from threading import Thread, Lock
from app import app
from data.queue import WorkQueue
from initdb import wipe_database, initialize_database, populate_database
QUEUE_NAME = 'testqueuename'
class AutoUpdatingQueue(object):
def __init__(self, queue_to_wrap):
self._queue = queue_to_wrap
def _wrapper(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
to_return = func(*args, **kwargs)
self._queue.update_metrics()
return to_return
return wrapper
def __getattr__(self, attr_name):
method_or_attr = getattr(self._queue, attr_name)
if callable(method_or_attr):
return self._wrapper(method_or_attr)
else:
return method_or_attr
class QueueTestCase(unittest.TestCase):
TEST_MESSAGE_1 = json.dumps({'data': 1})
def setUp(self):
self.transaction_factory = app.config['DB_TRANSACTION_FACTORY']
self.queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, self.transaction_factory))
wipe_database()
initialize_database()
populate_database()
class TestQueueThreads(QueueTestCase):
def test_queue_threads(self):
count = [20]
for i in range(count[0]):
self.queue.put([str(i)], self.TEST_MESSAGE_1)
lock = Lock()
def get(lock, count, queue):
item = queue.get()
if item is None:
return
self.assertEqual(self.TEST_MESSAGE_1, item.body)
with lock:
count[0] -= 1
threads = []
# The thread count needs to be a few times higher than the queue size
# count because some threads will get a None and thus won't decrement
# the counter.
for i in range(100):
t = Thread(target=get, args=(lock, count, self.queue))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(count[0], 0)
if __name__ == '__main__':
unittest.main()

View file

286
test/registry/fixtures.py Normal file
View file

@ -0,0 +1,286 @@
import copy
import logging.config
import json
import os
import shutil
from tempfile import NamedTemporaryFile
import pytest
from Crypto import Random
from flask import jsonify, g
from flask_principal import Identity
from app import storage
from data.database import (close_db_filter, configure, DerivedStorageForImage, QueueItem, Image,
TagManifest, TagManifestToManifest, Manifest, ManifestLegacyImage,
ManifestBlob, NamespaceGeoRestriction, User)
from data import model
from data.registry_model import registry_model
from endpoints.csrf import generate_csrf_token
from util.log import logfile_path
from test.registry.liveserverfixture import LiveServerExecutor
@pytest.fixture()
def registry_server_executor(app):
def generate_csrf():
return generate_csrf_token()
def set_supports_direct_download(enabled):
storage.put_content(['local_us'], 'supports_direct_download', 'true' if enabled else 'false')
return 'OK'
def delete_image(image_id):
image = Image.get(docker_image_id=image_id)
image.docker_image_id = 'DELETED'
image.save()
return 'OK'
def get_storage_replication_entry(image_id):
image = Image.get(docker_image_id=image_id)
QueueItem.select().where(QueueItem.queue_name ** ('%' + image.storage.uuid + '%')).get()
return 'OK'
def set_feature(feature_name, value):
import features
from app import app
old_value = features._FEATURES[feature_name].value
features._FEATURES[feature_name].value = value
app.config['FEATURE_%s' % feature_name] = value
return jsonify({'old_value': old_value})
def set_config_key(config_key, value):
from app import app as current_app
old_value = app.config.get(config_key)
app.config[config_key] = value
current_app.config[config_key] = value
# Close any existing connection.
close_db_filter(None)
# Reload the database config.
configure(app.config)
return jsonify({'old_value': old_value})
def clear_derived_cache():
DerivedStorageForImage.delete().execute()
return 'OK'
def clear_uncompressed_size(image_id):
image = model.image.get_image_by_id('devtable', 'newrepo', image_id)
image.storage.uncompressed_size = None
image.storage.save()
return 'OK'
def add_token():
another_token = model.token.create_delegate_token('devtable', 'newrepo', 'my-new-token',
'write')
return model.token.get_full_token_string(another_token)
def break_database():
# Close any existing connection.
close_db_filter(None)
# Reload the database config with an invalid connection.
config = copy.copy(app.config)
config['DB_URI'] = 'sqlite:///not/a/valid/database'
configure(config)
return 'OK'
def reload_app(server_hostname):
# Close any existing connection.
close_db_filter(None)
# Reload the database config.
app.config['SERVER_HOSTNAME'] = server_hostname[len('http://'):]
configure(app.config)
# Reload random after the process split, as it cannot be used uninitialized across forks.
Random.atfork()
# Required for anonymous calls to not exception.
g.identity = Identity(None, 'none')
if os.environ.get('DEBUGLOG') == 'true':
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
return 'OK'
def create_app_repository(namespace, name):
user = model.user.get_user(namespace)
model.repository.create_repository(namespace, name, user, repo_kind='application')
return 'OK'
def disable_namespace(namespace):
namespace_obj = model.user.get_namespace_user(namespace)
namespace_obj.enabled = False
namespace_obj.save()
return 'OK'
def delete_manifests():
ManifestLegacyImage.delete().execute()
ManifestBlob.delete().execute()
Manifest.delete().execute()
TagManifestToManifest.delete().execute()
TagManifest.delete().execute()
return 'OK'
def set_geo_block_for_namespace(namespace_name, iso_country_code):
NamespaceGeoRestriction.create(namespace=User.get(username=namespace_name),
description='',
unstructured_json={},
restricted_region_iso_code=iso_country_code)
return 'OK'
executor = LiveServerExecutor()
executor.register('generate_csrf', generate_csrf)
executor.register('set_supports_direct_download', set_supports_direct_download)
executor.register('delete_image', delete_image)
executor.register('get_storage_replication_entry', get_storage_replication_entry)
executor.register('set_feature', set_feature)
executor.register('set_config_key', set_config_key)
executor.register('clear_derived_cache', clear_derived_cache)
executor.register('clear_uncompressed_size', clear_uncompressed_size)
executor.register('add_token', add_token)
executor.register('break_database', break_database)
executor.register('reload_app', reload_app)
executor.register('create_app_repository', create_app_repository)
executor.register('disable_namespace', disable_namespace)
executor.register('delete_manifests', delete_manifests)
executor.register('set_geo_block_for_namespace', set_geo_block_for_namespace)
return executor
@pytest.fixture(params=['pre_oci_model', 'oci_model'])
def data_model(request):
return request.param
@pytest.fixture()
def liveserver_app(app, registry_server_executor, init_db_path, data_model):
# Change the data model being used.
registry_model.set_for_testing(data_model == 'oci_model')
registry_server_executor.apply_blueprint_to_app(app)
if os.environ.get('DEBUG', 'false').lower() == 'true':
app.config['DEBUG'] = True
# Copy the clean database to a new path. We cannot share the DB created by the
# normal app fixture, as it is already open in the local process.
local_db_file = NamedTemporaryFile(delete=True)
local_db_file.close()
shutil.copy2(init_db_path, local_db_file.name)
app.config['DB_URI'] = 'sqlite:///{0}'.format(local_db_file.name)
return app
@pytest.fixture()
def app_reloader(request, liveserver, registry_server_executor):
registry_server_executor.on(liveserver).reload_app(liveserver.url)
yield
class FeatureFlagValue(object):
""" Helper object which temporarily sets the value of a feature flag.
Usage:
with FeatureFlagValue('ANONYMOUS_ACCESS', False, registry_server_executor.on(liveserver)):
... Features.ANONYMOUS_ACCESS is False in this context ...
"""
def __init__(self, feature_flag, test_value, executor):
self.feature_flag = feature_flag
self.test_value = test_value
self.executor = executor
self.old_value = None
def __enter__(self):
result = self.executor.set_feature(self.feature_flag, self.test_value)
self.old_value = result.json()['old_value']
def __exit__(self, type, value, traceback):
self.executor.set_feature(self.feature_flag, self.old_value)
class ConfigChange(object):
""" Helper object which temporarily sets the value of a config key.
Usage:
with ConfigChange('SOMEKEY', 'value', registry_server_executor.on(liveserver)):
... app.config['SOMEKEY'] is 'value' in this context ...
"""
def __init__(self, config_key, test_value, executor, liveserver):
self.config_key = config_key
self.test_value = test_value
self.executor = executor
self.liveserver = liveserver
self.old_value = None
def __enter__(self):
result = self.executor.set_config_key(self.config_key, self.test_value)
self.old_value = result.json()['old_value']
def __exit__(self, type, value, traceback):
self.executor.set_config_key(self.config_key, self.old_value)
class ApiCaller(object):
def __init__(self, liveserver_session, registry_server_executor):
self.liveserver_session = liveserver_session
self.registry_server_executor = registry_server_executor
def conduct_auth(self, username, password):
r = self.post('/api/v1/signin',
data=json.dumps(dict(username=username, password=password)),
headers={'Content-Type': 'application/json'})
assert r.status_code == 200
def _adjust_params(self, kwargs):
csrf_token = self.registry_server_executor.on_session(self.liveserver_session).generate_csrf()
if 'params' not in kwargs:
kwargs['params'] = {}
kwargs['params'].update({
'_csrf_token': csrf_token,
})
return kwargs
def get(self, url, **kwargs):
kwargs = self._adjust_params(kwargs)
return self.liveserver_session.get(url, **kwargs)
def post(self, url, **kwargs):
kwargs = self._adjust_params(kwargs)
return self.liveserver_session.post(url, **kwargs)
def put(self, url, **kwargs):
kwargs = self._adjust_params(kwargs)
return self.liveserver_session.put(url, **kwargs)
def delete(self, url, **kwargs):
kwargs = self._adjust_params(kwargs)
return self.liveserver_session.delete(url, **kwargs)
def change_repo_visibility(self, namespace, repository, visibility):
self.post('/api/v1/repository/%s/%s/changevisibility' % (namespace, repository),
data=json.dumps(dict(visibility=visibility)),
headers={'Content-Type': 'application/json'})
@pytest.fixture(scope="function")
def api_caller(liveserver, registry_server_executor):
return ApiCaller(liveserver.new_session(), registry_server_executor)

View file

@ -0,0 +1,283 @@
import inspect
import json
import multiprocessing
import socket
import socketserver
import time
from contextlib import contextmanager
from urlparse import urlparse, urljoin
import pytest
import requests
from flask import request
from flask.blueprints import Blueprint
class liveFlaskServer(object):
""" Helper class for spawning a live Flask server for testing.
Based on https://github.com/jarus/flask-testing/blob/master/flask_testing/utils.py#L421
"""
def __init__(self, app, port_value):
self.app = app
self._port_value = port_value
self._process = None
def get_server_url(self):
"""
Return the url of the test server
"""
return 'http://localhost:%s' % self._port_value.value
def terminate_live_server(self):
if self._process:
self._process.terminate()
def spawn_live_server(self):
self._process = None
port_value = self._port_value
def worker(app, port):
# Based on solution: http://stackoverflow.com/a/27598916
# Monkey-patch the server_bind so we can determine the port bound by Flask.
# This handles the case where the port specified is `0`, which means that
# the OS chooses the port. This is the only known way (currently) of getting
# the port out of Flask once we call `run`.
original_socket_bind = socketserver.TCPServer.server_bind
def socket_bind_wrapper(self):
ret = original_socket_bind(self)
# Get the port and save it into the port_value, so the parent process
# can read it.
(_, port) = self.socket.getsockname()
port_value.value = port
socketserver.TCPServer.server_bind = original_socket_bind
return ret
socketserver.TCPServer.server_bind = socket_bind_wrapper
app.run(port=port, use_reloader=False)
retry_count = self.app.config.get('LIVESERVER_RETRY_COUNT', 3)
started = False
for _ in range(0, retry_count):
if started:
break
self._process = multiprocessing.Process(target=worker, args=(self.app, 0))
self._process.start()
# We must wait for the server to start listening, but give up
# after a specified maximum timeout
timeout = self.app.config.get('LIVESERVER_TIMEOUT', 10)
start_time = time.time()
while True:
time.sleep(0.1)
elapsed_time = (time.time() - start_time)
if elapsed_time > timeout:
break
if self._can_connect():
self.app.config['SERVER_HOSTNAME'] = 'localhost:%s' % self._port_value.value
started = True
break
if not started:
raise RuntimeError("Failed to start the server after %d retries. " % retry_count)
def _can_connect(self):
host, port = self._get_server_address()
if port == 0:
# Port specified by the user was 0, and the OS has not yet assigned
# the proper port.
return False
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.connect((host, port))
except socket.error:
success = False
else:
success = True
finally:
sock.close()
return success
def _get_server_address(self):
"""
Gets the server address used to test the connection with a socket.
Respects both the LIVESERVER_PORT config value and overriding
get_server_url()
"""
parts = urlparse(self.get_server_url())
host = parts.hostname
port = parts.port
if port is None:
if parts.scheme == 'http':
port = 80
elif parts.scheme == 'https':
port = 443
else:
raise RuntimeError("Unsupported server url scheme: %s" % parts.scheme)
return host, port
class LiveFixtureServerSession(object):
""" Helper class for calling the live server via a single requests Session. """
def __init__(self, base_url):
self.base_url = base_url
self.session = requests.Session()
def _get_url(self, url):
return urljoin(self.base_url, url)
def get(self, url, **kwargs):
return self.session.get(self._get_url(url), **kwargs)
def post(self, url, **kwargs):
return self.session.post(self._get_url(url), **kwargs)
def put(self, url, **kwargs):
return self.session.put(self._get_url(url), **kwargs)
def delete(self, url, **kwargs):
return self.session.delete(self._get_url(url), **kwargs)
def request(self, method, url, **kwargs):
return self.session.request(method, self._get_url(url), **kwargs)
class LiveFixtureServer(object):
""" Helper for interacting with a live server. """
def __init__(self, url):
self.url = url
@contextmanager
def session(self):
""" Yields a session for speaking to the live server. """
yield LiveFixtureServerSession(self.url)
def new_session(self):
""" Returns a new session for speaking to the live server. """
return LiveFixtureServerSession(self.url)
@pytest.fixture(scope='function')
def liveserver(liveserver_app):
""" Runs a live Flask server for the app for the duration of the test.
Based on https://github.com/jarus/flask-testing/blob/master/flask_testing/utils.py#L421
"""
context = liveserver_app.test_request_context()
context.push()
port = multiprocessing.Value('i', 0)
live_server = liveFlaskServer(liveserver_app, port)
try:
live_server.spawn_live_server()
yield LiveFixtureServer(live_server.get_server_url())
finally:
context.pop()
live_server.terminate_live_server()
@pytest.fixture(scope='function')
def liveserver_session(liveserver, liveserver_app):
""" Fixtures which instantiates a liveserver and returns a single session for
interacting with that server.
"""
return LiveFixtureServerSession(liveserver.url)
class LiveServerExecutor(object):
""" Helper class which can be used to register functions to be executed in the
same process as the live server. This is necessary because the live server
runs in a different process and, therefore, in order to execute state changes
outside of the server's normal flows (i.e. via code), it must be executed
*in-process* via an HTTP call. The LiveServerExecutor class abstracts away
all the setup for this process.
Usage:
def _perform_operation(first_param, second_param):
... do some operation in the app ...
return 'some value'
@pytest.fixture(scope="session")
def my_server_executor():
executor = LiveServerExecutor()
executor.register('performoperation', _perform_operation)
return executor
@pytest.fixture()
def liveserver_app(app, my_server_executor):
... other app setup here ...
my_server_executor.apply_blueprint_to_app(app)
return app
def test_mytest(liveserver, my_server_executor):
# Invokes 'performoperation' in the liveserver's process.
my_server_executor.on(liveserver).performoperation('first', 'second')
"""
def __init__(self):
self.funcs = {}
def register(self, fn_name, fn):
""" Registers the given function under the given name. """
self.funcs[fn_name] = fn
def apply_blueprint_to_app(self, app):
""" Applies a blueprint to the app, to support invocation from this executor. """
testbp = Blueprint('testbp', __name__)
def build_invoker(fn_name, fn):
path = '/' + fn_name
@testbp.route(path, methods=['POST'], endpoint=fn_name)
def _(**kwargs):
arg_values = request.get_json()['args']
return fn(*arg_values)
for fn_name, fn in self.funcs.iteritems():
build_invoker(fn_name, fn)
app.register_blueprint(testbp, url_prefix='/__test')
def on(self, server):
""" Returns an invoker for the given live server. """
return liveServerExecutorInvoker(self.funcs, server)
def on_session(self, server_session):
""" Returns an invoker for the given live server session. """
return liveServerExecutorInvoker(self.funcs, server_session)
class liveServerExecutorInvoker(object):
def __init__(self, funcs, server_or_session):
self._funcs = funcs
self._server_or_session = server_or_session
def __getattribute__(self, name):
if name.startswith('_'):
return object.__getattribute__(self, name)
if name not in self._funcs:
raise AttributeError('Unknown function: %s' % name)
def invoker(*args):
path = '/__test/%s' % name
headers = {'Content-Type': 'application/json'}
if isinstance(self._server_or_session, LiveFixtureServerSession):
return self._server_or_session.post(path, data=json.dumps({'args': args}), headers=headers)
else:
with self._server_or_session.session() as session:
return session.post(path, data=json.dumps({'args': args}), headers=headers)
return invoker

View file

@ -0,0 +1,228 @@
# -*- coding: utf-8 -*-
import random
import string
import pytest
from Crypto.PublicKey import RSA
from jwkest.jwk import RSAKey
from test.registry.fixtures import data_model
from test.registry.protocols import Image, layer_bytes_for_contents
from test.registry.protocol_v1 import V1Protocol
from test.registry.protocol_v2 import V2Protocol
@pytest.fixture(scope="session")
def basic_images():
""" Returns basic images for push and pull testing. """
# Note: order is from base layer down to leaf.
parent_bytes = layer_bytes_for_contents('parent contents')
image_bytes = layer_bytes_for_contents('some contents')
return [
Image(id='parentid', bytes=parent_bytes, parent_id=None),
Image(id='someid', bytes=image_bytes, parent_id='parentid'),
]
@pytest.fixture(scope="session")
def unicode_images():
""" Returns basic images for push and pull testing that contain unicode in the image metadata. """
# Note: order is from base layer down to leaf.
parent_bytes = layer_bytes_for_contents('parent contents')
image_bytes = layer_bytes_for_contents('some contents')
return [
Image(id='parentid', bytes=parent_bytes, parent_id=None),
Image(id='someid', bytes=image_bytes, parent_id='parentid',
config={'comment': u'the Pawe\xc5\x82 Kami\xc5\x84ski image',
'author': u'Sômé guy'}),
]
@pytest.fixture(scope="session")
def different_images():
""" Returns different basic images for push and pull testing. """
# Note: order is from base layer down to leaf.
parent_bytes = layer_bytes_for_contents('different parent contents')
image_bytes = layer_bytes_for_contents('some different contents')
return [
Image(id='anotherparentid', bytes=parent_bytes, parent_id=None),
Image(id='anothersomeid', bytes=image_bytes, parent_id='anotherparentid'),
]
@pytest.fixture(scope="session")
def sized_images():
""" Returns basic images (with sizes) for push and pull testing. """
# Note: order is from base layer down to leaf.
parent_bytes = layer_bytes_for_contents('parent contents', mode='')
image_bytes = layer_bytes_for_contents('some contents', mode='')
return [
Image(id='parentid', bytes=parent_bytes, parent_id=None, size=len(parent_bytes),
config={'foo': 'bar'}),
Image(id='someid', bytes=image_bytes, parent_id='parentid', size=len(image_bytes),
config={'foo': 'childbar', 'Entrypoint': ['hello']},
created='2018-04-03T18:37:09.284840891Z'),
]
@pytest.fixture(scope="session")
def multi_layer_images():
""" Returns complex images (with sizes) for push and pull testing. """
# Note: order is from base layer down to leaf.
layer1_bytes = layer_bytes_for_contents('layer 1 contents', mode='', other_files={
'file1': 'from-layer-1',
})
layer2_bytes = layer_bytes_for_contents('layer 2 contents', mode='', other_files={
'file2': 'from-layer-2',
})
layer3_bytes = layer_bytes_for_contents('layer 3 contents', mode='', other_files={
'file1': 'from-layer-3',
'file3': 'from-layer-3',
})
layer4_bytes = layer_bytes_for_contents('layer 4 contents', mode='', other_files={
'file3': 'from-layer-4',
})
layer5_bytes = layer_bytes_for_contents('layer 5 contents', mode='', other_files={
'file4': 'from-layer-5',
})
return [
Image(id='layer1', bytes=layer1_bytes, parent_id=None, size=len(layer1_bytes),
config={'internal_id': 'layer1'}),
Image(id='layer2', bytes=layer2_bytes, parent_id='layer1', size=len(layer2_bytes),
config={'internal_id': 'layer2'}),
Image(id='layer3', bytes=layer3_bytes, parent_id='layer2', size=len(layer3_bytes),
config={'internal_id': 'layer3'}),
Image(id='layer4', bytes=layer4_bytes, parent_id='layer3', size=len(layer4_bytes),
config={'internal_id': 'layer4'}),
Image(id='someid', bytes=layer5_bytes, parent_id='layer4', size=len(layer5_bytes),
config={'internal_id': 'layer5'}),
]
@pytest.fixture(scope="session")
def remote_images():
""" Returns images with at least one remote layer for push and pull testing. """
# Note: order is from base layer down to leaf.
remote_bytes = layer_bytes_for_contents('remote contents')
parent_bytes = layer_bytes_for_contents('parent contents')
image_bytes = layer_bytes_for_contents('some contents')
return [
Image(id='remoteid', bytes=remote_bytes, parent_id=None, urls=['http://some/url']),
Image(id='parentid', bytes=parent_bytes, parent_id='remoteid'),
Image(id='someid', bytes=image_bytes, parent_id='parentid'),
]
@pytest.fixture(scope="session")
def images_with_empty_layer():
""" Returns images for push and pull testing that contain an empty layer. """
# Note: order is from base layer down to leaf.
parent_bytes = layer_bytes_for_contents('parent contents')
empty_bytes = layer_bytes_for_contents('', empty=True)
image_bytes = layer_bytes_for_contents('some contents')
middle_bytes = layer_bytes_for_contents('middle')
return [
Image(id='parentid', bytes=parent_bytes, parent_id=None),
Image(id='emptyid', bytes=empty_bytes, parent_id='parentid', is_empty=True),
Image(id='middleid', bytes=middle_bytes, parent_id='emptyid'),
Image(id='emptyid2', bytes=empty_bytes, parent_id='middleid', is_empty=True),
Image(id='someid', bytes=image_bytes, parent_id='emptyid2'),
]
@pytest.fixture(scope="session")
def unicode_emoji_images():
""" Returns basic images for push and pull testing that contain unicode in the image metadata. """
# Note: order is from base layer down to leaf.
parent_bytes = layer_bytes_for_contents('parent contents')
image_bytes = layer_bytes_for_contents('some contents')
return [
Image(id='parentid', bytes=parent_bytes, parent_id=None),
Image(id='someid', bytes=image_bytes, parent_id='parentid',
config={'comment': u'😱',
'author': u'Sômé guy'}),
]
@pytest.fixture(scope="session")
def jwk():
return RSAKey(key=RSA.generate(2048))
@pytest.fixture(params=[V2Protocol])
def v2_protocol(request, jwk):
return request.param(jwk)
@pytest.fixture()
def v22_protocol(request, jwk):
return V2Protocol(jwk, schema2=True)
@pytest.fixture(params=[V1Protocol])
def v1_protocol(request, jwk):
return request.param(jwk)
@pytest.fixture(params=['schema1', 'schema2'])
def manifest_protocol(request, data_model, jwk):
return V2Protocol(jwk, schema2=(request == 'schema2' and data_model == 'oci_model'))
@pytest.fixture(params=['v1', 'v2_1', 'v2_2'])
def pusher(request, data_model, jwk):
if request.param == 'v1':
return V1Protocol(jwk)
if request.param == 'v2_2' and data_model == 'oci_model':
return V2Protocol(jwk, schema2=True)
return V2Protocol(jwk)
@pytest.fixture(params=['v1', 'v2_1'])
def legacy_puller(request, data_model, jwk):
if request.param == 'v1':
return V1Protocol(jwk)
return V2Protocol(jwk)
@pytest.fixture(params=['v1', 'v2_1'])
def legacy_pusher(request, data_model, jwk):
if request.param == 'v1':
return V1Protocol(jwk)
return V2Protocol(jwk)
@pytest.fixture(params=['v1', 'v2_1', 'v2_2'])
def puller(request, data_model, jwk):
if request.param == 'v1':
return V1Protocol(jwk)
if request.param == 'v2_2' and data_model == 'oci_model':
return V2Protocol(jwk, schema2=True)
return V2Protocol(jwk)
@pytest.fixture(params=[V1Protocol, V2Protocol])
def loginer(request, jwk):
return request.param(jwk)
@pytest.fixture(scope="session")
def random_layer_data():
size = 4096
contents = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(size))
return layer_bytes_for_contents(contents)

View file

@ -0,0 +1,263 @@
import json
from cStringIO import StringIO
from enum import Enum, unique
from digest.checksums import compute_simple, compute_tarsum
from test.registry.protocols import (RegistryProtocol, Failures, ProtocolOptions, PushResult,
PullResult)
@unique
class V1ProtocolSteps(Enum):
""" Defines the various steps of the protocol, for matching failures. """
PUT_IMAGES = 'put-images'
GET_IMAGES = 'get-images'
PUT_TAG = 'put-tag'
PUT_IMAGE_JSON = 'put-image-json'
DELETE_TAG = 'delete-tag'
GET_TAG = 'get-tag'
GET_LAYER = 'get-layer'
class V1Protocol(RegistryProtocol):
FAILURE_CODES = {
V1ProtocolSteps.PUT_IMAGES: {
Failures.INVALID_AUTHENTICATION: 403,
Failures.UNAUTHENTICATED: 401,
Failures.UNAUTHORIZED: 403,
Failures.APP_REPOSITORY: 405,
Failures.SLASH_REPOSITORY: 404,
Failures.INVALID_REPOSITORY: 400,
Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
Failures.NAMESPACE_DISABLED: 400,
Failures.READ_ONLY: 405,
Failures.MIRROR_ONLY: 405,
Failures.MIRROR_MISCONFIGURED: 500,
Failures.MIRROR_ROBOT_MISSING: 400,
Failures.READONLY_REGISTRY: 405,
},
V1ProtocolSteps.GET_IMAGES: {
Failures.INVALID_AUTHENTICATION: 403,
Failures.UNAUTHENTICATED: 403,
Failures.UNAUTHORIZED: 403,
Failures.APP_REPOSITORY: 404,
Failures.ANONYMOUS_NOT_ALLOWED: 401,
Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
Failures.NAMESPACE_DISABLED: 400,
},
V1ProtocolSteps.PUT_IMAGE_JSON: {
Failures.INVALID_IMAGES: 400,
Failures.READ_ONLY: 405,
Failures.MIRROR_ONLY: 405,
Failures.MIRROR_MISCONFIGURED: 500,
Failures.MIRROR_ROBOT_MISSING: 400,
Failures.READONLY_REGISTRY: 405,
},
V1ProtocolSteps.PUT_TAG: {
Failures.MISSING_TAG: 404,
Failures.INVALID_TAG: 400,
Failures.INVALID_IMAGES: 400,
Failures.NAMESPACE_DISABLED: 400,
Failures.READ_ONLY: 405,
Failures.MIRROR_ONLY: 405,
Failures.MIRROR_MISCONFIGURED: 500,
Failures.MIRROR_ROBOT_MISSING: 400,
Failures.READONLY_REGISTRY: 405,
},
V1ProtocolSteps.GET_LAYER: {
Failures.GEO_BLOCKED: 403,
},
V1ProtocolSteps.GET_TAG: {
Failures.UNKNOWN_TAG: 404,
},
}
def __init__(self, jwk):
pass
def _auth_for_credentials(self, credentials):
if credentials is None:
return None
return credentials
def ping(self, session):
assert session.get('/v1/_ping').status_code == 200
def login(self, session, username, password, scopes, expect_success):
data = {
'username': username,
'password': password,
}
response = self.conduct(session, 'POST', '/v1/users/', json_data=data, expected_status=400)
assert (response.text == '"Username or email already exists"') == expect_success
def pull(self, session, namespace, repo_name, tag_names, images, credentials=None,
expected_failure=None, options=None):
options = options or ProtocolOptions()
auth = self._auth_for_credentials(credentials)
tag_names = [tag_names] if isinstance(tag_names, str) else tag_names
prefix = '/v1/repositories/%s/' % self.repo_name(namespace, repo_name)
# Ping!
self.ping(session)
# GET /v1/repositories/{namespace}/{repository}/images
headers = {'X-Docker-Token': 'true'}
result = self.conduct(session, 'GET', prefix + 'images', auth=auth, headers=headers,
expected_status=(200, expected_failure, V1ProtocolSteps.GET_IMAGES))
if result.status_code != 200:
return
headers = {}
if credentials is not None:
headers['Authorization'] = 'token ' + result.headers['www-authenticate']
else:
assert not 'www-authenticate' in result.headers
# GET /v1/repositories/{namespace}/{repository}/tags
image_ids = self.conduct(session, 'GET', prefix + 'tags', headers=headers).json()
for tag_name in tag_names:
# GET /v1/repositories/{namespace}/{repository}/tags/<tag_name>
image_id_data = self.conduct(session, 'GET', prefix + 'tags/' + tag_name,
headers=headers,
expected_status=(200, expected_failure,
V1ProtocolSteps.GET_TAG))
if tag_name not in image_ids:
assert expected_failure == Failures.UNKNOWN_TAG
return None
tag_image_id = image_ids[tag_name]
assert image_id_data.json() == tag_image_id
# Retrieve the ancestry of the tagged image.
image_prefix = '/v1/images/%s/' % tag_image_id
ancestors = self.conduct(session, 'GET', image_prefix + 'ancestry', headers=headers).json()
assert len(ancestors) == len(images)
for index, image_id in enumerate(reversed(ancestors)):
# /v1/images/{imageID}/{ancestry, json, layer}
image_prefix = '/v1/images/%s/' % image_id
self.conduct(session, 'GET', image_prefix + 'ancestry', headers=headers)
result = self.conduct(session, 'GET', image_prefix + 'json', headers=headers)
assert result.json()['id'] == image_id
# Ensure we can HEAD the image layer.
self.conduct(session, 'HEAD', image_prefix + 'layer', headers=headers)
# And retrieve the layer data.
result = self.conduct(session, 'GET', image_prefix + 'layer', headers=headers,
expected_status=(200, expected_failure, V1ProtocolSteps.GET_LAYER),
options=options)
if result.status_code == 200:
assert result.content == images[index].bytes
return PullResult(manifests=None, image_ids=image_ids)
def push(self, session, namespace, repo_name, tag_names, images, credentials=None,
expected_failure=None, options=None):
auth = self._auth_for_credentials(credentials)
tag_names = [tag_names] if isinstance(tag_names, str) else tag_names
# Ping!
self.ping(session)
# PUT /v1/repositories/{namespace}/{repository}/
result = self.conduct(session, 'PUT',
'/v1/repositories/%s/' % self.repo_name(namespace, repo_name),
expected_status=(201, expected_failure, V1ProtocolSteps.PUT_IMAGES),
json_data={},
auth=auth)
if result.status_code != 201:
return
headers = {}
headers['Authorization'] = 'token ' + result.headers['www-authenticate']
for image in images:
assert image.urls is None
# PUT /v1/images/{imageID}/json
image_json_data = {'id': image.id}
if image.size is not None:
image_json_data['Size'] = image.size
if image.parent_id is not None:
image_json_data['parent'] = image.parent_id
if image.config is not None:
image_json_data['config'] = image.config
if image.created is not None:
image_json_data['created'] = image.created
image_json = json.dumps(image_json_data)
response = self.conduct(session, 'PUT', '/v1/images/%s/json' % image.id,
data=image_json, headers=headers,
expected_status=(200, expected_failure,
V1ProtocolSteps.PUT_IMAGE_JSON))
if response.status_code != 200:
return
# PUT /v1/images/{imageID}/checksum (old style)
old_checksum = compute_tarsum(StringIO(image.bytes), image_json)
checksum_headers = {'X-Docker-Checksum': old_checksum}
checksum_headers.update(headers)
self.conduct(session, 'PUT', '/v1/images/%s/checksum' % image.id,
headers=checksum_headers)
# PUT /v1/images/{imageID}/layer
self.conduct(session, 'PUT', '/v1/images/%s/layer' % image.id,
data=StringIO(image.bytes), headers=headers)
# PUT /v1/images/{imageID}/checksum (new style)
checksum = compute_simple(StringIO(image.bytes), image_json)
checksum_headers = {'X-Docker-Checksum-Payload': checksum}
checksum_headers.update(headers)
self.conduct(session, 'PUT', '/v1/images/%s/checksum' % image.id,
headers=checksum_headers)
# PUT /v1/repositories/{namespace}/{repository}/tags/latest
for tag_name in tag_names:
self.conduct(session, 'PUT',
'/v1/repositories/%s/tags/%s' % (self.repo_name(namespace, repo_name), tag_name),
data='"%s"' % images[-1].id,
headers=headers,
expected_status=(200, expected_failure, V1ProtocolSteps.PUT_TAG))
# PUT /v1/repositories/{namespace}/{repository}/images
self.conduct(session, 'PUT',
'/v1/repositories/%s/images' % self.repo_name(namespace, repo_name),
expected_status=204, headers=headers)
return PushResult(manifests=None, headers=headers)
def delete(self, session, namespace, repo_name, tag_names, credentials=None,
expected_failure=None, options=None):
auth = self._auth_for_credentials(credentials)
tag_names = [tag_names] if isinstance(tag_names, str) else tag_names
# Ping!
self.ping(session)
for tag_name in tag_names:
# DELETE /v1/repositories/{namespace}/{repository}/tags/{tag}
self.conduct(session, 'DELETE',
'/v1/repositories/%s/tags/%s' % (self.repo_name(namespace, repo_name), tag_name),
auth=auth,
expected_status=(200, expected_failure, V1ProtocolSteps.DELETE_TAG))
def tag(self, session, namespace, repo_name, tag_name, image, credentials=None,
expected_failure=None, options=None):
auth = self._auth_for_credentials(credentials)
self.conduct(session, 'PUT',
'/v1/repositories/%s/tags/%s' % (self.repo_name(namespace, repo_name), tag_name),
data='"%s"' % image.id,
auth=auth,
expected_status=(200, expected_failure, V1ProtocolSteps.PUT_TAG))

View file

@ -0,0 +1,705 @@
import hashlib
import json
from enum import Enum, unique
from image.docker.schema1 import (DockerSchema1ManifestBuilder, DockerSchema1Manifest,
DOCKER_SCHEMA1_CONTENT_TYPES)
from image.docker.schema2 import DOCKER_SCHEMA2_CONTENT_TYPES
from image.docker.schema2.manifest import DockerSchema2ManifestBuilder
from image.docker.schema2.config import DockerSchema2Config
from image.docker.schema2.list import DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE
from image.docker.schemas import parse_manifest_from_bytes
from test.registry.protocols import (RegistryProtocol, Failures, ProtocolOptions, PushResult,
PullResult)
from util.bytes import Bytes
@unique
class V2ProtocolSteps(Enum):
""" Defines the various steps of the protocol, for matching failures. """
AUTH = 'auth'
BLOB_HEAD_CHECK = 'blob-head-check'
GET_MANIFEST = 'get-manifest'
GET_MANIFEST_LIST = 'get-manifest-list'
PUT_MANIFEST = 'put-manifest'
PUT_MANIFEST_LIST = 'put-manifest-list'
MOUNT_BLOB = 'mount-blob'
CATALOG = 'catalog'
LIST_TAGS = 'list-tags'
START_UPLOAD = 'start-upload'
GET_BLOB = 'get-blob'
class V2Protocol(RegistryProtocol):
FAILURE_CODES = {
V2ProtocolSteps.AUTH: {
Failures.UNAUTHENTICATED: 401,
Failures.INVALID_AUTHENTICATION: 401,
Failures.INVALID_REGISTRY: 400,
Failures.APP_REPOSITORY: 405,
Failures.ANONYMOUS_NOT_ALLOWED: 401,
Failures.INVALID_REPOSITORY: 400,
Failures.SLASH_REPOSITORY: 400,
Failures.NAMESPACE_DISABLED: 405,
},
V2ProtocolSteps.MOUNT_BLOB: {
Failures.UNAUTHORIZED_FOR_MOUNT: 202,
Failures.READONLY_REGISTRY: 405,
},
V2ProtocolSteps.GET_MANIFEST: {
Failures.UNKNOWN_TAG: 404,
Failures.UNAUTHORIZED: 401,
Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
Failures.ANONYMOUS_NOT_ALLOWED: 401,
},
V2ProtocolSteps.GET_BLOB: {
Failures.GEO_BLOCKED: 403,
},
V2ProtocolSteps.BLOB_HEAD_CHECK: {
Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
},
V2ProtocolSteps.START_UPLOAD: {
Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
Failures.READ_ONLY: 401,
Failures.MIRROR_ONLY: 401,
Failures.MIRROR_MISCONFIGURED: 401,
Failures.MIRROR_ROBOT_MISSING: 401,
Failures.READ_ONLY: 401,
Failures.READONLY_REGISTRY: 405,
},
V2ProtocolSteps.PUT_MANIFEST: {
Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
Failures.MISSING_TAG: 404,
Failures.INVALID_TAG: 404,
Failures.INVALID_IMAGES: 400,
Failures.INVALID_BLOB: 400,
Failures.UNSUPPORTED_CONTENT_TYPE: 415,
Failures.READ_ONLY: 401,
Failures.MIRROR_ONLY: 401,
Failures.MIRROR_MISCONFIGURED: 401,
Failures.MIRROR_ROBOT_MISSING: 401,
Failures.READONLY_REGISTRY: 405,
},
V2ProtocolSteps.PUT_MANIFEST_LIST: {
Failures.INVALID_MANIFEST: 400,
Failures.READ_ONLY: 401,
Failures.MIRROR_ONLY: 401,
Failures.MIRROR_MISCONFIGURED: 401,
Failures.MIRROR_ROBOT_MISSING: 401,
Failures.READONLY_REGISTRY: 405,
}
}
def __init__(self, jwk, schema2=False):
self.jwk = jwk
self.schema2 = schema2
def ping(self, session):
result = session.get('/v2/')
assert result.status_code == 401
assert result.headers['Docker-Distribution-API-Version'] == 'registry/2.0'
def login(self, session, username, password, scopes, expect_success):
scopes = scopes if isinstance(scopes, list) else [scopes]
params = {
'account': username,
'service': 'localhost:5000',
'scope': scopes,
}
auth = (username, password)
if not username or not password:
auth = None
response = session.get('/v2/auth', params=params, auth=auth)
if expect_success:
assert response.status_code / 100 == 2
else:
assert response.status_code / 100 == 4
return response
def auth(self, session, credentials, namespace, repo_name, scopes=None,
expected_failure=None):
"""
Performs the V2 Auth flow, returning the token (if any) and the response.
Spec: https://docs.docker.com/registry/spec/auth/token/
"""
scopes = scopes or []
auth = None
username = None
if credentials is not None:
username, _ = credentials
auth = credentials
params = {
'account': username,
'service': 'localhost:5000',
}
if scopes:
params['scope'] = scopes
response = self.conduct(session, 'GET', '/v2/auth', params=params, auth=auth,
expected_status=(200, expected_failure, V2ProtocolSteps.AUTH))
expect_token = (expected_failure is None or
not V2Protocol.FAILURE_CODES[V2ProtocolSteps.AUTH].get(expected_failure))
if expect_token:
assert response.json().get('token') is not None
return response.json().get('token'), response
return None, response
def pull_list(self, session, namespace, repo_name, tag_names, manifestlist,
credentials=None, expected_failure=None, options=None):
options = options or ProtocolOptions()
scopes = options.scopes or ['repository:%s:push,pull' % self.repo_name(namespace, repo_name)]
tag_names = [tag_names] if isinstance(tag_names, str) else tag_names
# Ping!
self.ping(session)
# Perform auth and retrieve a token.
token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
expected_failure=expected_failure)
if token is None:
assert V2Protocol.FAILURE_CODES[V2ProtocolSteps.AUTH].get(expected_failure)
return
headers = {
'Authorization': 'Bearer ' + token,
'Accept': ','.join(DOCKER_SCHEMA2_CONTENT_TYPES),
}
for tag_name in tag_names:
# Retrieve the manifest for the tag or digest.
response = self.conduct(session, 'GET',
'/v2/%s/manifests/%s' % (self.repo_name(namespace, repo_name),
tag_name),
expected_status=(200, expected_failure,
V2ProtocolSteps.GET_MANIFEST_LIST),
headers=headers)
if expected_failure is not None:
return None
# Parse the returned manifest list and ensure it matches.
ct = response.headers['Content-Type']
assert ct == DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE
retrieved = parse_manifest_from_bytes(Bytes.for_string_or_unicode(response.text), ct)
assert retrieved.schema_version == 2
assert retrieved.is_manifest_list
assert retrieved.digest == manifestlist.digest
# Pull each of the manifests inside and ensure they can be retrieved.
for manifest_digest in retrieved.child_manifest_digests():
response = self.conduct(session, 'GET',
'/v2/%s/manifests/%s' % (self.repo_name(namespace, repo_name),
manifest_digest),
expected_status=(200, expected_failure,
V2ProtocolSteps.GET_MANIFEST),
headers=headers)
if expected_failure is not None:
return None
ct = response.headers['Content-Type']
manifest = parse_manifest_from_bytes(Bytes.for_string_or_unicode(response.text), ct)
assert not manifest.is_manifest_list
assert manifest.digest == manifest_digest
def push_list(self, session, namespace, repo_name, tag_names, manifestlist, manifests, blobs,
credentials=None, expected_failure=None, options=None):
options = options or ProtocolOptions()
scopes = options.scopes or ['repository:%s:push,pull' % self.repo_name(namespace, repo_name)]
tag_names = [tag_names] if isinstance(tag_names, str) else tag_names
# Ping!
self.ping(session)
# Perform auth and retrieve a token.
token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
expected_failure=expected_failure)
if token is None:
assert V2Protocol.FAILURE_CODES[V2ProtocolSteps.AUTH].get(expected_failure)
return
headers = {
'Authorization': 'Bearer ' + token,
'Accept': ','.join(options.accept_mimetypes) if options.accept_mimetypes is not None else '*/*',
}
# Push all blobs.
if not self._push_blobs(blobs, session, namespace, repo_name, headers, options,
expected_failure):
return
# Push the individual manifests.
for manifest in manifests:
manifest_headers = {'Content-Type': manifest.media_type}
manifest_headers.update(headers)
self.conduct(session, 'PUT',
'/v2/%s/manifests/%s' % (self.repo_name(namespace, repo_name), manifest.digest),
data=manifest.bytes.as_encoded_str(),
expected_status=(202, expected_failure, V2ProtocolSteps.PUT_MANIFEST),
headers=manifest_headers)
# Push the manifest list.
for tag_name in tag_names:
manifest_headers = {'Content-Type': manifestlist.media_type}
manifest_headers.update(headers)
if options.manifest_content_type is not None:
manifest_headers['Content-Type'] = options.manifest_content_type
self.conduct(session, 'PUT',
'/v2/%s/manifests/%s' % (self.repo_name(namespace, repo_name), tag_name),
data=manifestlist.bytes.as_encoded_str(),
expected_status=(202, expected_failure, V2ProtocolSteps.PUT_MANIFEST_LIST),
headers=manifest_headers)
return PushResult(manifests=None, headers=headers)
def build_schema2(self, images, blobs, options):
builder = DockerSchema2ManifestBuilder()
for image in images:
checksum = 'sha256:' + hashlib.sha256(image.bytes).hexdigest()
if image.urls is None:
blobs[checksum] = image.bytes
# If invalid blob references were requested, just make it up.
if options.manifest_invalid_blob_references:
checksum = 'sha256:' + hashlib.sha256('notarealthing').hexdigest()
if not image.is_empty:
builder.add_layer(checksum, len(image.bytes), urls=image.urls)
def history_for_image(image):
history = {
'created': '2018-04-03T18:37:09.284840891Z',
'created_by': (('/bin/sh -c #(nop) ENTRYPOINT %s' % image.config['Entrypoint'])
if image.config and image.config.get('Entrypoint')
else '/bin/sh -c #(nop) %s' % image.id),
}
if image.is_empty:
history['empty_layer'] = True
return history
config = {
"os": "linux",
"rootfs": {
"type": "layers",
"diff_ids": []
},
"history": [history_for_image(image) for image in images],
}
if images[-1].config:
config['config'] = images[-1].config
config_json = json.dumps(config, ensure_ascii=options.ensure_ascii)
schema2_config = DockerSchema2Config(Bytes.for_string_or_unicode(config_json))
builder.set_config(schema2_config)
blobs[schema2_config.digest] = schema2_config.bytes.as_encoded_str()
return builder.build(ensure_ascii=options.ensure_ascii)
def build_schema1(self, namespace, repo_name, tag_name, images, blobs, options, arch='amd64'):
builder = DockerSchema1ManifestBuilder(namespace, repo_name, tag_name, arch)
for image in reversed(images):
assert image.urls is None
checksum = 'sha256:' + hashlib.sha256(image.bytes).hexdigest()
blobs[checksum] = image.bytes
# If invalid blob references were requested, just make it up.
if options.manifest_invalid_blob_references:
checksum = 'sha256:' + hashlib.sha256('notarealthing').hexdigest()
layer_dict = {'id': image.id, 'parent': image.parent_id}
if image.config is not None:
layer_dict['config'] = image.config
if image.size is not None:
layer_dict['Size'] = image.size
if image.created is not None:
layer_dict['created'] = image.created
builder.add_layer(checksum, json.dumps(layer_dict, ensure_ascii=options.ensure_ascii))
# Build the manifest.
built = builder.build(self.jwk, ensure_ascii=options.ensure_ascii)
# Validate it before we send it.
DockerSchema1Manifest(built.bytes)
return built
def push(self, session, namespace, repo_name, tag_names, images, credentials=None,
expected_failure=None, options=None):
options = options or ProtocolOptions()
scopes = options.scopes or ['repository:%s:push,pull' % self.repo_name(namespace, repo_name)]
tag_names = [tag_names] if isinstance(tag_names, str) else tag_names
# Ping!
self.ping(session)
# Perform auth and retrieve a token.
token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
expected_failure=expected_failure)
if token is None:
assert V2Protocol.FAILURE_CODES[V2ProtocolSteps.AUTH].get(expected_failure)
return
headers = {
'Authorization': 'Bearer ' + token,
'Accept': ','.join(options.accept_mimetypes) if options.accept_mimetypes is not None else '*/*',
}
# Build fake manifests.
manifests = {}
blobs = {}
for tag_name in tag_names:
if self.schema2:
manifests[tag_name] = self.build_schema2(images, blobs, options)
else:
manifests[tag_name] = self.build_schema1(namespace, repo_name, tag_name, images, blobs,
options)
# Push the blob data.
if not self._push_blobs(blobs, session, namespace, repo_name, headers, options,
expected_failure):
return
# Write a manifest for each tag.
for tag_name in tag_names:
manifest = manifests[tag_name]
# Write the manifest. If we expect it to be invalid, we expect a 404 code. Otherwise, we
# expect a 202 response for success.
put_code = 404 if options.manifest_invalid_blob_references else 202
manifest_headers = {'Content-Type': manifest.media_type}
manifest_headers.update(headers)
if options.manifest_content_type is not None:
manifest_headers['Content-Type'] = options.manifest_content_type
tag_or_digest = tag_name if not options.push_by_manifest_digest else manifest.digest
self.conduct(session, 'PUT',
'/v2/%s/manifests/%s' % (self.repo_name(namespace, repo_name), tag_or_digest),
data=manifest.bytes.as_encoded_str(),
expected_status=(put_code, expected_failure, V2ProtocolSteps.PUT_MANIFEST),
headers=manifest_headers)
return PushResult(manifests=manifests, headers=headers)
def _push_blobs(self, blobs, session, namespace, repo_name, headers, options, expected_failure):
for blob_digest, blob_bytes in blobs.iteritems():
if not options.skip_head_checks:
# Blob data should not yet exist.
self.conduct(session, 'HEAD',
'/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name), blob_digest),
expected_status=(404, expected_failure, V2ProtocolSteps.BLOB_HEAD_CHECK),
headers=headers)
# Check for mounting of blobs.
if options.mount_blobs and blob_digest in options.mount_blobs:
self.conduct(session, 'POST',
'/v2/%s/blobs/uploads/' % self.repo_name(namespace, repo_name),
params={
'mount': blob_digest,
'from': options.mount_blobs[blob_digest],
},
expected_status=(201, expected_failure, V2ProtocolSteps.MOUNT_BLOB),
headers=headers)
if expected_failure is not None:
return
else:
# Start a new upload of the blob data.
response = self.conduct(session, 'POST',
'/v2/%s/blobs/uploads/' % self.repo_name(namespace, repo_name),
expected_status=(202, expected_failure,
V2ProtocolSteps.START_UPLOAD),
headers=headers)
if response.status_code != 202:
continue
upload_uuid = response.headers['Docker-Upload-UUID']
new_upload_location = response.headers['Location']
assert new_upload_location.startswith('http://localhost:5000')
# We need to make this relative just for the tests because the live server test
# case modifies the port.
location = response.headers['Location'][len('http://localhost:5000'):]
# PATCH the data into the blob.
if options.chunks_for_upload is None:
self.conduct(session, 'PATCH', location, data=blob_bytes, expected_status=204,
headers=headers)
else:
# If chunked upload is requested, upload the data as a series of chunks, checking
# status at every point.
for chunk_data in options.chunks_for_upload:
if len(chunk_data) == 3:
(start_byte, end_byte, expected_code) = chunk_data
else:
(start_byte, end_byte) = chunk_data
expected_code = 204
patch_headers = {'Range': 'bytes=%s-%s' % (start_byte, end_byte)}
patch_headers.update(headers)
contents_chunk = blob_bytes[start_byte:end_byte]
self.conduct(session, 'PATCH', location, data=contents_chunk,
expected_status=expected_code,
headers=patch_headers)
if expected_code != 204:
return False
# Retrieve the upload status at each point, and ensure it is valid.
status_url = '/v2/%s/blobs/uploads/%s' % (self.repo_name(namespace, repo_name),
upload_uuid)
response = self.conduct(session, 'GET', status_url, expected_status=204,
headers=headers)
assert response.headers['Docker-Upload-UUID'] == upload_uuid
assert response.headers['Range'] == "bytes=0-%s" % end_byte
if options.cancel_blob_upload:
self.conduct(session, 'DELETE', location, params=dict(digest=blob_digest),
expected_status=204, headers=headers)
# Ensure the upload was canceled.
status_url = '/v2/%s/blobs/uploads/%s' % (self.repo_name(namespace, repo_name),
upload_uuid)
self.conduct(session, 'GET', status_url, expected_status=404, headers=headers)
return False
# Finish the blob upload with a PUT.
response = self.conduct(session, 'PUT', location, params=dict(digest=blob_digest),
expected_status=201, headers=headers)
assert response.headers['Docker-Content-Digest'] == blob_digest
# Ensure the blob exists now.
response = self.conduct(session, 'HEAD',
'/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name),
blob_digest),
expected_status=200, headers=headers)
assert response.headers['Docker-Content-Digest'] == blob_digest
assert response.headers['Content-Length'] == str(len(blob_bytes))
# And retrieve the blob data.
if not options.skip_blob_push_checks:
result = self.conduct(session, 'GET',
'/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name), blob_digest),
headers=headers, expected_status=200)
assert result.content == blob_bytes
return True
def delete(self, session, namespace, repo_name, tag_names, credentials=None,
expected_failure=None, options=None):
options = options or ProtocolOptions()
scopes = options.scopes or ['repository:%s:*' % self.repo_name(namespace, repo_name)]
tag_names = [tag_names] if isinstance(tag_names, str) else tag_names
# Ping!
self.ping(session)
# Perform auth and retrieve a token.
token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
expected_failure=expected_failure)
if token is None:
return None
headers = {
'Authorization': 'Bearer ' + token,
}
for tag_name in tag_names:
self.conduct(session, 'DELETE',
'/v2/%s/manifests/%s' % (self.repo_name(namespace, repo_name), tag_name),
headers=headers,
expected_status=202)
def pull(self, session, namespace, repo_name, tag_names, images, credentials=None,
expected_failure=None, options=None):
options = options or ProtocolOptions()
scopes = options.scopes or ['repository:%s:pull' % self.repo_name(namespace, repo_name)]
tag_names = [tag_names] if isinstance(tag_names, str) else tag_names
# Ping!
self.ping(session)
# Perform auth and retrieve a token.
token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
expected_failure=expected_failure)
if token is None and not options.attempt_pull_without_token:
return None
headers = {}
if token:
headers = {
'Authorization': 'Bearer ' + token,
}
if self.schema2:
headers['Accept'] = ','.join(options.accept_mimetypes
if options.accept_mimetypes is not None
else DOCKER_SCHEMA2_CONTENT_TYPES)
manifests = {}
image_ids = {}
for tag_name in tag_names:
# Retrieve the manifest for the tag or digest.
response = self.conduct(session, 'GET',
'/v2/%s/manifests/%s' % (self.repo_name(namespace, repo_name),
tag_name),
expected_status=(200, expected_failure, V2ProtocolSteps.GET_MANIFEST),
headers=headers)
if response.status_code == 401:
assert 'WWW-Authenticate' in response.headers
response.encoding = 'utf-8'
if expected_failure is not None:
return None
# Ensure the manifest returned by us is valid.
ct = response.headers['Content-Type']
if not self.schema2:
assert ct in DOCKER_SCHEMA1_CONTENT_TYPES
manifest = parse_manifest_from_bytes(Bytes.for_string_or_unicode(response.text), ct)
manifests[tag_name] = manifest
if manifest.schema_version == 1:
image_ids[tag_name] = manifest.leaf_layer_v1_image_id
# Verify the blobs.
layer_index = 0
empty_count = 0
blob_digests = list(manifest.blob_digests)
for image in images:
if manifest.schema_version == 2 and image.is_empty:
empty_count += 1
continue
# If the layer is remote, then we expect the blob to *not* exist in the system.
blob_digest = blob_digests[layer_index]
expected_status = 404 if image.urls else 200
result = self.conduct(session, 'GET',
'/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name),
blob_digest),
expected_status=(expected_status, expected_failure,
V2ProtocolSteps.GET_BLOB),
headers=headers,
options=options)
if expected_status == 200:
assert result.content == image.bytes
layer_index += 1
assert (len(blob_digests) + empty_count) >= len(images) # Schema 2 has 1 extra for config
return PullResult(manifests=manifests, image_ids=image_ids)
def tags(self, session, namespace, repo_name, page_size=2, credentials=None, options=None,
expected_failure=None):
options = options or ProtocolOptions()
scopes = options.scopes or ['repository:%s:pull' % self.repo_name(namespace, repo_name)]
# Ping!
self.ping(session)
# Perform auth and retrieve a token.
headers = {}
if credentials is not None:
token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
expected_failure=expected_failure)
if token is None:
return None
headers = {
'Authorization': 'Bearer ' + token,
}
results = []
url = '/v2/%s/tags/list' % (self.repo_name(namespace, repo_name))
params = {}
if page_size is not None:
params['n'] = page_size
while True:
response = self.conduct(session, 'GET', url, headers=headers, params=params,
expected_status=(200, expected_failure, V2ProtocolSteps.LIST_TAGS))
data = response.json()
assert len(data['tags']) <= page_size
results.extend(data['tags'])
if not response.headers.get('Link'):
return results
link_url = response.headers['Link']
v2_index = link_url.find('/v2/')
url = link_url[v2_index:]
return results
def catalog(self, session, page_size=2, credentials=None, options=None, expected_failure=None,
namespace=None, repo_name=None, bearer_token=None):
options = options or ProtocolOptions()
scopes = options.scopes or []
# Ping!
self.ping(session)
# Perform auth and retrieve a token.
headers = {}
if credentials is not None:
token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
expected_failure=expected_failure)
if token is None:
return None
headers = {
'Authorization': 'Bearer ' + token,
}
if bearer_token is not None:
headers = {
'Authorization': 'Bearer ' + bearer_token,
}
results = []
url = '/v2/_catalog'
params = {}
if page_size is not None:
params['n'] = page_size
while True:
response = self.conduct(session, 'GET', url, headers=headers, params=params,
expected_status=(200, expected_failure, V2ProtocolSteps.CATALOG))
data = response.json()
assert len(data['repositories']) <= page_size
results.extend(data['repositories'])
if not response.headers.get('Link'):
return results
link_url = response.headers['Link']
v2_index = link_url.find('/v2/')
url = link_url[v2_index:]
return results

148
test/registry/protocols.py Normal file
View file

@ -0,0 +1,148 @@
import json
import tarfile
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from cStringIO import StringIO
from enum import Enum, unique
from six import add_metaclass
from image.docker.schema2 import EMPTY_LAYER_BYTES
Image = namedtuple('Image', ['id', 'parent_id', 'bytes', 'size', 'config', 'created', 'urls',
'is_empty'])
Image.__new__.__defaults__ = (None, None, None, None, False)
PushResult = namedtuple('PushResult', ['manifests', 'headers'])
PullResult = namedtuple('PullResult', ['manifests', 'image_ids'])
def layer_bytes_for_contents(contents, mode='|gz', other_files=None, empty=False):
if empty:
return EMPTY_LAYER_BYTES
layer_data = StringIO()
tar_file = tarfile.open(fileobj=layer_data, mode='w' + mode)
def add_file(name, contents):
tar_file_info = tarfile.TarInfo(name=name)
tar_file_info.type = tarfile.REGTYPE
tar_file_info.size = len(contents)
tar_file_info.mtime = 1
tar_file.addfile(tar_file_info, StringIO(contents))
add_file('contents', contents)
if other_files is not None:
for file_name, file_contents in other_files.iteritems():
add_file(file_name, file_contents)
tar_file.close()
layer_bytes = layer_data.getvalue()
layer_data.close()
return layer_bytes
@unique
class Failures(Enum):
""" Defines the various forms of expected failure. """
UNAUTHENTICATED = 'unauthenticated'
UNAUTHORIZED = 'unauthorized'
INVALID_AUTHENTICATION = 'invalid-authentication'
INVALID_REGISTRY = 'invalid-registry'
INVALID_REPOSITORY = 'invalid-repository'
SLASH_REPOSITORY = 'slash-repository'
APP_REPOSITORY = 'app-repository'
UNKNOWN_TAG = 'unknown-tag'
ANONYMOUS_NOT_ALLOWED = 'anonymous-not-allowed'
DISALLOWED_LIBRARY_NAMESPACE = 'disallowed-library-namespace'
MISSING_TAG = 'missing-tag'
INVALID_TAG = 'invalid-tag'
INVALID_MANIFEST = 'invalid-manifest'
INVALID_IMAGES = 'invalid-images'
UNSUPPORTED_CONTENT_TYPE = 'unsupported-content-type'
INVALID_BLOB = 'invalid-blob'
NAMESPACE_DISABLED = 'namespace-disabled'
UNAUTHORIZED_FOR_MOUNT = 'unauthorized-for-mount'
GEO_BLOCKED = 'geo-blocked'
READ_ONLY = 'read-only'
MIRROR_ONLY = 'mirror-only'
MIRROR_MISCONFIGURED = 'mirror-misconfigured'
MIRROR_ROBOT_MISSING = 'mirror-robot-missing'
READONLY_REGISTRY = 'readonly-registry'
class ProtocolOptions(object):
def __init__(self):
self.scopes = None
self.cancel_blob_upload = False
self.manifest_invalid_blob_references = False
self.chunks_for_upload = None
self.skip_head_checks = False
self.manifest_content_type = None
self.accept_mimetypes = None
self.mount_blobs = None
self.push_by_manifest_digest = False
self.request_addr = None
self.skip_blob_push_checks = False
self.ensure_ascii = True
self.attempt_pull_without_token = False
@add_metaclass(ABCMeta)
class RegistryProtocol(object):
""" Interface for protocols. """
FAILURE_CODES = {}
@abstractmethod
def login(self, session, username, password, scopes, expect_success):
""" Performs the login flow with the given credentials, over the given scopes. """
@abstractmethod
def pull(self, session, namespace, repo_name, tag_names, images, credentials=None,
expected_failure=None, options=None):
""" Pulls the given tag via the given session, using the given credentials, and
ensures the given images match.
"""
@abstractmethod
def push(self, session, namespace, repo_name, tag_names, images, credentials=None,
expected_failure=None, options=None):
""" Pushes the specified images as the given tag via the given session, using
the given credentials.
"""
@abstractmethod
def delete(self, session, namespace, repo_name, tag_names, credentials=None,
expected_failure=None, options=None):
""" Deletes some tags. """
def repo_name(self, namespace, repo_name):
if namespace:
return '%s/%s' % (namespace, repo_name)
return repo_name
def conduct(self, session, method, url, expected_status=200, params=None, data=None,
json_data=None, headers=None, auth=None, options=None):
if json_data is not None:
data = json.dumps(json_data).encode('utf-8')
headers = headers or {}
headers['Content-Type'] = 'application/json'
if options and options.request_addr:
headers = headers or {}
headers['X-Override-Remote-Addr-For-Testing'] = options.request_addr
if isinstance(expected_status, tuple):
expected_status, expected_failure, protocol_step = expected_status
if expected_failure is not None:
failures = self.__class__.FAILURE_CODES.get(protocol_step, {})
expected_status = failures.get(expected_failure, expected_status)
result = session.request(method, url, params=params, data=data, headers=headers, auth=auth)
msg = "Expected response %s, got %s: %s" % (expected_status, result.status_code, result.text)
assert result.status_code == expected_status, msg
return result

File diff suppressed because it is too large Load diff

2315
test/registry_tests.py Normal file

File diff suppressed because it is too large Load diff

510
test/specs.py Normal file
View file

@ -0,0 +1,510 @@
import json
import hashlib
from flask import url_for
from base64 import b64encode
NO_REPO = None
PUBLIC = 'public'
PUBLIC_REPO_NAME = 'publicrepo'
PUBLIC_REPO = PUBLIC + '/' + PUBLIC_REPO_NAME
PRIVATE = 'devtable'
PRIVATE_REPO_NAME = 'shared'
PRIVATE_REPO = PRIVATE + '/' + PRIVATE_REPO_NAME
ORG = 'buynlarge'
ORG_REPO = ORG + '/orgrepo'
ANOTHER_ORG_REPO = ORG + '/anotherorgrepo'
NEW_ORG_REPO = ORG + '/neworgrepo'
ORG_REPO_NAME = 'orgrepo'
ORG_READERS = 'readers'
ORG_OWNER = 'devtable'
ORG_OWNERS = 'owners'
ORG_READERS = 'readers'
FAKE_MANIFEST = 'unknown_tag'
FAKE_DIGEST = 'sha256:' + hashlib.sha256('fake').hexdigest()
FAKE_IMAGE_ID = 'fake-image'
FAKE_UPLOAD_ID = 'fake-upload'
FAKE_TAG_NAME = 'fake-tag'
FAKE_USERNAME = 'fakeuser'
FAKE_TOKEN = 'fake-token'
FAKE_WEBHOOK = 'fake-webhook'
BUILD_UUID = '123'
TRIGGER_UUID = '123'
NEW_ORG_REPO_DETAILS = {
'repository': 'fake-repository',
'visibility': 'private',
'description': '',
'namespace': ORG,
}
NEW_USER_DETAILS = {
'username': 'bobby',
'password': 'password',
'email': 'bobby@tables.com',
}
SEND_RECOVERY_DETAILS = {
'email': 'jacob.moshenko@gmail.com',
}
SIGNIN_DETAILS = {
'username': 'devtable',
'password': 'password',
}
FILE_DROP_DETAILS = {
'mimeType': 'application/zip',
}
CHANGE_PERMISSION_DETAILS = {
'role': 'admin',
}
CREATE_BUILD_DETAILS = {
'file_id': 'fake-file-id',
}
CHANGE_VISIBILITY_DETAILS = {
'visibility': 'public',
}
CREATE_TOKEN_DETAILS = {
'friendlyName': 'A new token',
}
UPDATE_REPO_DETAILS = {
'description': 'A new description',
}
class IndexV1TestSpec(object):
def __init__(self, url, sess_repo=None, anon_code=403, no_access_code=403,
read_code=200, creator_code=200, admin_code=200):
self._url = url
self._method = 'GET'
self._data = None
self.sess_repo = sess_repo
self.anon_code = anon_code
self.no_access_code = no_access_code
self.read_code = read_code
self.creator_code = creator_code
self.admin_code = admin_code
def gen_basic_auth(self, username, password):
encoded = b64encode('%s:%s' % (username, password))
return 'basic %s' % encoded
def set_data_from_obj(self, json_serializable):
self._data = json.dumps(json_serializable)
return self
def set_method(self, method):
self._method = method
return self
def get_client_args(self):
kwargs = {
'method': self._method
}
if self._data or self._method == 'POST' or self._method == 'PUT' or self._method == 'PATCH':
kwargs['data'] = self._data if self._data else '{}'
kwargs['content_type'] = 'application/json'
return self._url, kwargs
def build_v1_index_specs():
return [
IndexV1TestSpec(url_for('v1.get_image_layer', image_id=FAKE_IMAGE_ID),
PUBLIC_REPO, 404, 404, 404, 404, 404),
IndexV1TestSpec(url_for('v1.get_image_layer', image_id=FAKE_IMAGE_ID),
PRIVATE_REPO, 403, 403, 404, 403, 404),
IndexV1TestSpec(url_for('v1.get_image_layer', image_id=FAKE_IMAGE_ID),
ORG_REPO, 403, 403, 404, 403, 404),
IndexV1TestSpec(url_for('v1.get_image_layer', image_id=FAKE_IMAGE_ID),
ANOTHER_ORG_REPO, 403, 403, 403, 403, 404),
IndexV1TestSpec(url_for('v1.put_image_layer', image_id=FAKE_IMAGE_ID),
PUBLIC_REPO, 403, 403, 403, 403, 403).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_layer', image_id=FAKE_IMAGE_ID),
PRIVATE_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_layer', image_id=FAKE_IMAGE_ID),
ORG_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_layer', image_id=FAKE_IMAGE_ID),
ANOTHER_ORG_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_checksum',
image_id=FAKE_IMAGE_ID),
PUBLIC_REPO, 403, 403, 403, 403, 403).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_checksum',
image_id=FAKE_IMAGE_ID),
PRIVATE_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_checksum',
image_id=FAKE_IMAGE_ID),
ORG_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_checksum',
image_id=FAKE_IMAGE_ID),
ANOTHER_ORG_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.get_image_json', image_id=FAKE_IMAGE_ID),
PUBLIC_REPO, 404, 404, 404, 404, 404),
IndexV1TestSpec(url_for('v1.get_image_json', image_id=FAKE_IMAGE_ID),
PRIVATE_REPO, 403, 403, 404, 403, 404),
IndexV1TestSpec(url_for('v1.get_image_json', image_id=FAKE_IMAGE_ID),
ORG_REPO, 403, 403, 404, 403, 404),
IndexV1TestSpec(url_for('v1.get_image_json', image_id=FAKE_IMAGE_ID),
ANOTHER_ORG_REPO, 403, 403, 403, 403, 404),
IndexV1TestSpec(url_for('v1.get_image_ancestry', image_id=FAKE_IMAGE_ID),
PUBLIC_REPO, 404, 404, 404, 404, 404),
IndexV1TestSpec(url_for('v1.get_image_ancestry', image_id=FAKE_IMAGE_ID),
PRIVATE_REPO, 403, 403, 404, 403, 404),
IndexV1TestSpec(url_for('v1.get_image_ancestry', image_id=FAKE_IMAGE_ID),
ORG_REPO, 403, 403, 404, 403, 404),
IndexV1TestSpec(url_for('v1.get_image_ancestry', image_id=FAKE_IMAGE_ID),
ANOTHER_ORG_REPO, 403, 403, 403, 403, 404),
IndexV1TestSpec(url_for('v1.put_image_json', image_id=FAKE_IMAGE_ID),
PUBLIC_REPO, 403, 403, 403, 403, 403).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_json', image_id=FAKE_IMAGE_ID),
PRIVATE_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_json', image_id=FAKE_IMAGE_ID),
ORG_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_image_json', image_id=FAKE_IMAGE_ID),
ANOTHER_ORG_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.create_user'), NO_REPO, 400, 400, 400, 400,
400).set_method('POST').set_data_from_obj(NEW_USER_DETAILS),
IndexV1TestSpec(url_for('v1.get_user'), NO_REPO, 404, 200, 200, 200, 200),
IndexV1TestSpec(url_for('v1.update_user', username=FAKE_USERNAME),
NO_REPO, 403, 403, 403, 403, 403).set_method('PUT'),
IndexV1TestSpec(url_for('v1.create_repository', repository=PUBLIC_REPO),
NO_REPO, 403, 403, 403, 403, 403).set_method('PUT'),
IndexV1TestSpec(url_for('v1.create_repository', repository=PRIVATE_REPO),
NO_REPO, 403, 403, 403, 403, 201).set_method('PUT'),
IndexV1TestSpec(url_for('v1.create_repository', repository=ORG_REPO),
NO_REPO, 403, 403, 403, 403, 201).set_method('PUT'),
IndexV1TestSpec(url_for('v1.create_repository', repository=ANOTHER_ORG_REPO),
NO_REPO, 403, 403, 403, 403, 201).set_method('PUT'),
IndexV1TestSpec(url_for('v1.create_repository', repository=NEW_ORG_REPO),
NO_REPO, 401, 403, 403, 201, 201).set_method('PUT'),
IndexV1TestSpec(url_for('v1.update_images', repository=PUBLIC_REPO),
NO_REPO, 403, 403, 403, 403, 403).set_method('PUT'),
IndexV1TestSpec(url_for('v1.update_images', repository=PRIVATE_REPO),
NO_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.update_images', repository=ORG_REPO), NO_REPO,
403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.update_images', repository=ANOTHER_ORG_REPO), NO_REPO,
403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.get_repository_images',
repository=PUBLIC_REPO),
NO_REPO, 200, 200, 200, 200, 200),
IndexV1TestSpec(url_for('v1.get_repository_images',
repository=PRIVATE_REPO),
NO_REPO, 403, 403, 200, 403, 200),
IndexV1TestSpec(url_for('v1.get_repository_images',
repository=ORG_REPO),
NO_REPO, 403, 403, 200, 403, 200),
IndexV1TestSpec(url_for('v1.get_repository_images',
repository=ANOTHER_ORG_REPO),
NO_REPO, 403, 403, 403, 403, 200),
IndexV1TestSpec(url_for('v1.delete_repository_images',
repository=PUBLIC_REPO),
NO_REPO, 501, 501, 501, 501, 501).set_method('DELETE'),
IndexV1TestSpec(url_for('v1.put_repository_auth', repository=PUBLIC_REPO),
NO_REPO, 501, 501, 501, 501, 501).set_method('PUT'),
IndexV1TestSpec(url_for('v1.get_search'), NO_REPO, 200, 200, 200, 200, 200),
IndexV1TestSpec(url_for('v1.ping'), NO_REPO, 200, 200, 200, 200, 200),
IndexV1TestSpec(url_for('v1.get_tags', repository=PUBLIC_REPO), NO_REPO,
200, 200, 200, 200, 200),
IndexV1TestSpec(url_for('v1.get_tags', repository=PRIVATE_REPO), NO_REPO,
403, 403, 200, 403, 200),
IndexV1TestSpec(url_for('v1.get_tags', repository=ORG_REPO), NO_REPO,
403, 403, 200, 403, 200),
IndexV1TestSpec(url_for('v1.get_tags', repository=ANOTHER_ORG_REPO), NO_REPO,
403, 403, 403, 403, 200),
IndexV1TestSpec(url_for('v1.get_tag', repository=PUBLIC_REPO,
tag=FAKE_TAG_NAME), NO_REPO, 404, 404, 404, 404, 404),
IndexV1TestSpec(url_for('v1.get_tag', repository=PRIVATE_REPO,
tag=FAKE_TAG_NAME), NO_REPO, 403, 403, 404, 403, 404),
IndexV1TestSpec(url_for('v1.get_tag', repository=ORG_REPO,
tag=FAKE_TAG_NAME), NO_REPO, 403, 403, 404, 403, 404),
IndexV1TestSpec(url_for('v1.get_tag', repository=ANOTHER_ORG_REPO,
tag=FAKE_TAG_NAME), NO_REPO, 403, 403, 403, 403, 404),
IndexV1TestSpec(url_for('v1.put_tag', repository=PUBLIC_REPO,
tag=FAKE_TAG_NAME),
NO_REPO, 403, 403, 403, 403, 403).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_tag', repository=PRIVATE_REPO,
tag=FAKE_TAG_NAME),
NO_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_tag', repository=ORG_REPO,
tag=FAKE_TAG_NAME),
NO_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.put_tag', repository=ANOTHER_ORG_REPO,
tag=FAKE_TAG_NAME),
NO_REPO, 403, 403, 403, 403, 400).set_method('PUT'),
IndexV1TestSpec(url_for('v1.delete_tag', repository=PUBLIC_REPO,
tag=FAKE_TAG_NAME),
NO_REPO, 403, 403, 403, 403, 403).set_method('DELETE'),
IndexV1TestSpec(url_for('v1.delete_tag', repository=PRIVATE_REPO,
tag=FAKE_TAG_NAME),
NO_REPO, 403, 403, 403, 403, 400).set_method('DELETE'),
IndexV1TestSpec(url_for('v1.delete_tag', repository=ORG_REPO,
tag=FAKE_TAG_NAME),
NO_REPO, 403, 403, 403, 403, 400).set_method('DELETE'),
IndexV1TestSpec(url_for('v1.delete_tag', repository=ANOTHER_ORG_REPO,
tag=FAKE_TAG_NAME),
NO_REPO, 403, 403, 403, 403, 400).set_method('DELETE'),
]
class IndexV2TestSpec(object):
def __init__(self, index_name, method_name, repo_name, scope=None, **kwargs):
self.index_name = index_name
self.repo_name = repo_name
self.method_name = method_name
default_scope = 'push,pull' if method_name != 'GET' and method_name != 'HEAD' else 'pull'
self.scope = scope or default_scope
self.kwargs = kwargs
self.anon_code = 401
self.no_access_code = 403
self.read_code = 200
self.admin_code = 200
self.creator_code = 200
def request_status(self, anon_code=401, no_access_code=403, read_code=200, creator_code=200,
admin_code=200):
self.anon_code = anon_code
self.no_access_code = no_access_code
self.read_code = read_code
self.creator_code = creator_code
self.admin_code = admin_code
return self
def get_url(self):
return url_for(self.index_name, repository=self.repo_name, **self.kwargs)
def gen_basic_auth(self, username, password):
encoded = b64encode('%s:%s' % (username, password))
return 'basic %s' % encoded
def get_scope_string(self):
return 'repository:%s:%s' % (self.repo_name, self.scope)
def build_v2_index_specs():
return [
# v2.list_all_tags
IndexV2TestSpec('v2.list_all_tags', 'GET', PUBLIC_REPO).
request_status(200, 200, 200, 200, 200),
IndexV2TestSpec('v2.list_all_tags', 'GET', PRIVATE_REPO).
request_status(401, 401, 200, 401, 200),
IndexV2TestSpec('v2.list_all_tags', 'GET', ORG_REPO).
request_status(401, 401, 200, 401, 200),
IndexV2TestSpec('v2.list_all_tags', 'GET', ANOTHER_ORG_REPO).
request_status(401, 401, 401, 401, 200),
# v2.fetch_manifest_by_tagname
IndexV2TestSpec('v2.fetch_manifest_by_tagname', 'GET', PUBLIC_REPO, manifest_ref=FAKE_MANIFEST).
request_status(404, 404, 404, 404, 404),
IndexV2TestSpec('v2.fetch_manifest_by_tagname', 'GET', PRIVATE_REPO,
manifest_ref=FAKE_MANIFEST).
request_status(401, 401, 404, 401, 404),
IndexV2TestSpec('v2.fetch_manifest_by_tagname', 'GET', ORG_REPO, manifest_ref=FAKE_MANIFEST).
request_status(401, 401, 404, 401, 404),
IndexV2TestSpec('v2.fetch_manifest_by_tagname', 'GET', ANOTHER_ORG_REPO,
manifest_ref=FAKE_MANIFEST).
request_status(401, 401, 401, 401, 404),
# v2.fetch_manifest_by_digest
IndexV2TestSpec('v2.fetch_manifest_by_digest', 'GET', PUBLIC_REPO, manifest_ref=FAKE_DIGEST).
request_status(404, 404, 404, 404, 404),
IndexV2TestSpec('v2.fetch_manifest_by_digest', 'GET', PRIVATE_REPO, manifest_ref=FAKE_DIGEST).
request_status(401, 401, 404, 401, 404),
IndexV2TestSpec('v2.fetch_manifest_by_digest', 'GET', ORG_REPO, manifest_ref=FAKE_DIGEST).
request_status(401, 401, 404, 401, 404),
IndexV2TestSpec('v2.fetch_manifest_by_digest', 'GET', ANOTHER_ORG_REPO,
manifest_ref=FAKE_DIGEST).
request_status(401, 401, 401, 401, 404),
# v2.write_manifest_by_tagname
IndexV2TestSpec('v2.write_manifest_by_tagname', 'PUT', PUBLIC_REPO, manifest_ref=FAKE_MANIFEST).
request_status(401, 401, 401, 401, 401),
IndexV2TestSpec('v2.write_manifest_by_tagname', 'PUT', PRIVATE_REPO,
manifest_ref=FAKE_MANIFEST).
request_status(401, 401, 401, 401, 400),
IndexV2TestSpec('v2.write_manifest_by_tagname', 'PUT', ORG_REPO, manifest_ref=FAKE_MANIFEST).
request_status(401, 401, 401, 401, 400),
IndexV2TestSpec('v2.write_manifest_by_tagname', 'PUT', ANOTHER_ORG_REPO,
manifest_ref=FAKE_MANIFEST).
request_status(401, 401, 401, 401, 400),
# v2.write_manifest_by_digest
IndexV2TestSpec('v2.write_manifest_by_digest', 'PUT', PUBLIC_REPO, manifest_ref=FAKE_DIGEST).
request_status(401, 401, 401, 401, 401),
IndexV2TestSpec('v2.write_manifest_by_digest', 'PUT', PRIVATE_REPO, manifest_ref=FAKE_DIGEST).
request_status(401, 401, 401, 401, 400),
IndexV2TestSpec('v2.write_manifest_by_digest', 'PUT', ORG_REPO, manifest_ref=FAKE_DIGEST).
request_status(401, 401, 401, 401, 400),
IndexV2TestSpec('v2.write_manifest_by_digest', 'PUT', ANOTHER_ORG_REPO,
manifest_ref=FAKE_DIGEST).
request_status(401, 401, 401, 401, 400),
# v2.delete_manifest_by_digest
IndexV2TestSpec('v2.delete_manifest_by_digest', 'DELETE', PUBLIC_REPO,
manifest_ref=FAKE_DIGEST).
request_status(401, 401, 401, 401, 401),
IndexV2TestSpec('v2.delete_manifest_by_digest', 'DELETE', PRIVATE_REPO,
manifest_ref=FAKE_DIGEST).
request_status(401, 401, 401, 401, 404),
IndexV2TestSpec('v2.delete_manifest_by_digest', 'DELETE', ORG_REPO, manifest_ref=FAKE_DIGEST).
request_status(401, 401, 401, 401, 404),
IndexV2TestSpec('v2.delete_manifest_by_digest', 'DELETE', ANOTHER_ORG_REPO,
manifest_ref=FAKE_DIGEST).
request_status(401, 401, 401, 401, 404),
# v2.check_blob_exists
IndexV2TestSpec('v2.check_blob_exists', 'HEAD', PUBLIC_REPO, digest=FAKE_DIGEST).
request_status(404, 404, 404, 404, 404),
IndexV2TestSpec('v2.check_blob_exists', 'HEAD', PRIVATE_REPO, digest=FAKE_DIGEST).
request_status(401, 401, 404, 401, 404),
IndexV2TestSpec('v2.check_blob_exists', 'HEAD', ORG_REPO, digest=FAKE_DIGEST).
request_status(401, 401, 404, 401, 404),
IndexV2TestSpec('v2.check_blob_exists', 'HEAD', ANOTHER_ORG_REPO, digest=FAKE_DIGEST).
request_status(401, 401, 401, 401, 404),
# v2.download_blob
IndexV2TestSpec('v2.download_blob', 'GET', PUBLIC_REPO, digest=FAKE_DIGEST).
request_status(404, 404, 404, 404, 404),
IndexV2TestSpec('v2.download_blob', 'GET', PRIVATE_REPO, digest=FAKE_DIGEST).
request_status(401, 401, 404, 401, 404),
IndexV2TestSpec('v2.download_blob', 'GET', ORG_REPO, digest=FAKE_DIGEST).
request_status(401, 401, 404, 401, 404),
IndexV2TestSpec('v2.download_blob', 'GET', ANOTHER_ORG_REPO, digest=FAKE_DIGEST).
request_status(401, 401, 401, 401, 404),
# v2.start_blob_upload
IndexV2TestSpec('v2.start_blob_upload', 'POST', PUBLIC_REPO).
request_status(401, 401, 401, 401, 401),
IndexV2TestSpec('v2.start_blob_upload', 'POST', PRIVATE_REPO).
request_status(401, 401, 401, 401, 202),
IndexV2TestSpec('v2.start_blob_upload', 'POST', ORG_REPO).
request_status(401, 401, 401, 401, 202),
IndexV2TestSpec('v2.start_blob_upload', 'POST', ANOTHER_ORG_REPO).
request_status(401, 401, 401, 401, 202),
# v2.fetch_existing_upload
IndexV2TestSpec('v2.fetch_existing_upload', 'GET', PUBLIC_REPO, 'push,pull',
upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 401),
IndexV2TestSpec('v2.fetch_existing_upload', 'GET', PRIVATE_REPO, 'push,pull',
upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 404),
IndexV2TestSpec('v2.fetch_existing_upload', 'GET', ORG_REPO, 'push,pull',
upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 404),
IndexV2TestSpec('v2.fetch_existing_upload', 'GET', ANOTHER_ORG_REPO, 'push,pull',
upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 404),
# v2.upload_chunk
IndexV2TestSpec('v2.upload_chunk', 'PATCH', PUBLIC_REPO, upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 401),
IndexV2TestSpec('v2.upload_chunk', 'PATCH', PRIVATE_REPO, upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 404),
IndexV2TestSpec('v2.upload_chunk', 'PATCH', ORG_REPO, upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 404),
IndexV2TestSpec('v2.upload_chunk', 'PATCH', ANOTHER_ORG_REPO, upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 404),
# v2.monolithic_upload_or_last_chunk
IndexV2TestSpec('v2.monolithic_upload_or_last_chunk', 'PUT', PUBLIC_REPO,
upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 401),
IndexV2TestSpec('v2.monolithic_upload_or_last_chunk', 'PUT', PRIVATE_REPO,
upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 400),
IndexV2TestSpec('v2.monolithic_upload_or_last_chunk', 'PUT', ORG_REPO,
upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 400),
IndexV2TestSpec('v2.monolithic_upload_or_last_chunk', 'PUT', ANOTHER_ORG_REPO,
upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 400),
# v2.cancel_upload
IndexV2TestSpec('v2.cancel_upload', 'DELETE', PUBLIC_REPO, upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 401),
IndexV2TestSpec('v2.cancel_upload', 'DELETE', PRIVATE_REPO, upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 404),
IndexV2TestSpec('v2.cancel_upload', 'DELETE', ORG_REPO, upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 404),
IndexV2TestSpec('v2.cancel_upload', 'DELETE', ANOTHER_ORG_REPO, upload_uuid=FAKE_UPLOAD_ID).
request_status(401, 401, 401, 401, 404),
]

4345
test/test_api_usage.py Normal file

File diff suppressed because it is too large Load diff

69
test/test_certs_install.sh Executable file
View file

@ -0,0 +1,69 @@
#!/usr/bin/env bash
set -e
echo "> Starting certs install test"
# Set up all locations needed for the test
QUAYPATH=${QUAYPATH:-"."}
SCRIPT_LOCATION=${SCRIPT_LOCATION:-"/quay-registry/conf/init"}
# Parameters: (quay config dir, certifcate dir, number of certs expected).
function call_script_and_check_num_certs {
QUAYCONFIG=$1 CERTDIR=$2 ${SCRIPT_LOCATION}/certs_install.sh
if [ $? -ne 0 ]; then
echo "Failed to install $3 certs"
exit 1;
fi
certs_found=$(ls /etc/pki/ca-trust/source/anchors | wc -l)
if [ ${certs_found} -ne "$3" ]; then
echo "Expected there to be $3 in ca-certificates, found $certs_found"
exit 1
fi
}
# Create a dummy cert we can test to install
# echo '{"CN":"CA","key":{"algo":"rsa","size":2048}}' | cfssl gencert -initca - | cfssljson -bare test
openssl req -new -newkey rsa:4096 -days 3650 -nodes -x509 \
-subj "/C=US/ST=NY/L=NYC/O=Dis/CN=self-signed" \
-keyout test-key.pem -out test.pem
# Create temp dirs we can test with
WORK_DIR=`mktemp -d`
CERTS_WORKDIR=`mktemp -d`
# deletes the temp directory
function cleanup {
rm -rf "$WORK_DIR"
rm -rf "$CERTS_WORKDIR"
rm test.pem
rm test-key.pem
}
# register the cleanup function to be called on the EXIT signal
trap cleanup EXIT
# Test calling with empty directory to not fail
call_script_and_check_num_certs ${WORK_DIR} ${CERTS_WORKDIR} 0
if [ "$?" -ne 0 ]; then
echo "Failed to install certs with no files in the directory"
exit 1
fi
# Move an ldap cert into the temp directory and test that installation
cp test.pem ${WORK_DIR}/ldap.crt
call_script_and_check_num_certs ${WORK_DIR} ${CERTS_WORKDIR} 1
# Move 1 cert to extra cert dir and test
cp test.pem ${CERTS_WORKDIR}/cert1.crt
call_script_and_check_num_certs ${WORK_DIR} ${CERTS_WORKDIR} 2
# Move another cert to extra cer dir and test all three exist
cp test.pem ${CERTS_WORKDIR}/cert2.crt
call_script_and_check_num_certs ${WORK_DIR} ${CERTS_WORKDIR} 3
echo "> Certs install script test succeeded"
exit 0

701
test/test_endpoints.py Normal file
View file

@ -0,0 +1,701 @@
# coding=utf-8
import json as py_json
import time
import unittest
import base64
import zlib
from mock import patch
from io import BytesIO
from urllib import urlencode
from urlparse import urlparse, urlunparse, parse_qs
from datetime import datetime, timedelta
import jwt
from Crypto.PublicKey import RSA
from flask import url_for
from jwkest.jwk import RSAKey
from app import app
from data import model
from data.database import ServiceKeyApprovalType
from endpoints import keyserver
from endpoints.api import api, api_bp
from endpoints.api.user import Signin
from endpoints.keyserver import jwk_with_kid
from endpoints.csrf import OAUTH_CSRF_TOKEN_NAME
from endpoints.web import web as web_bp
from endpoints.webhooks import webhooks as webhooks_bp
from initdb import setup_database_for_testing, finished_database_for_testing
from test.helpers import assert_action_logged
from util.security.token import encode_public_private_token
from util.registry.gzipinputstream import WINDOW_BUFFER_SIZE
try:
app.register_blueprint(web_bp, url_prefix='')
except ValueError:
# This blueprint was already registered
pass
try:
app.register_blueprint(webhooks_bp, url_prefix='/webhooks')
except ValueError:
# This blueprint was already registered
pass
try:
app.register_blueprint(keyserver.key_server, url_prefix='')
except ValueError:
# This blueprint was already registered
pass
try:
app.register_blueprint(api_bp, url_prefix='/api')
except ValueError:
# This blueprint was already registered
pass
CSRF_TOKEN_KEY = '_csrf_token'
CSRF_TOKEN = '123csrfforme'
class EndpointTestCase(unittest.TestCase):
maxDiff = None
def _add_csrf(self, without_csrf):
parts = urlparse(without_csrf)
query = parse_qs(parts[4])
self._set_csrf()
query[CSRF_TOKEN_KEY] = CSRF_TOKEN
return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:]))
def _set_csrf(self):
with self.app.session_transaction() as sess:
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
sess[OAUTH_CSRF_TOKEN_NAME] = 'someoauthtoken'
def setUp(self):
setup_database_for_testing(self)
self.app = app.test_client()
self.ctx = app.test_request_context()
self.ctx.__enter__()
def tearDown(self):
finished_database_for_testing(self)
self.ctx.__exit__(True, None, None)
def getResponse(self, resource_name, expected_code=200, **kwargs):
rv = self.app.get(url_for(resource_name, **kwargs))
self.assertEquals(rv.status_code, expected_code)
return rv.data
def deleteResponse(self, resource_name, headers=None, expected_code=200, **kwargs):
headers = headers or {}
rv = self.app.delete(url_for(resource_name, **kwargs), headers=headers)
self.assertEquals(rv.status_code, expected_code)
return rv.data
def deleteEmptyResponse(self, resource_name, headers=None, expected_code=204, **kwargs):
headers = headers or {}
rv = self.app.delete(url_for(resource_name, **kwargs), headers=headers)
self.assertEquals(rv.status_code, expected_code)
self.assertEquals(rv.data, '') # ensure response body empty
return
def putResponse(self, resource_name, headers=None, data=None, expected_code=200, **kwargs):
headers = headers or {}
data = data or {}
rv = self.app.put(url_for(resource_name, **kwargs), headers=headers, data=py_json.dumps(data))
self.assertEquals(rv.status_code, expected_code)
return rv.data
def postResponse(self, resource_name, headers=None, data=None, form=None, with_csrf=True, expected_code=200, **kwargs):
headers = headers or {}
form = form or {}
url = url_for(resource_name, **kwargs)
if with_csrf:
url = self._add_csrf(url)
post_data = None
if form:
post_data = form
elif data:
post_data = py_json.dumps(data)
rv = self.app.post(url, headers=headers, data=post_data)
if expected_code is not None:
self.assertEquals(rv.status_code, expected_code)
return rv
def login(self, username, password):
rv = self.app.post(self._add_csrf(api.url_for(Signin)),
data=py_json.dumps(dict(username=username, password=password)),
headers={"Content-Type": "application/json"})
self.assertEquals(rv.status_code, 200)
class BuildLogsTestCase(EndpointTestCase):
build_uuid = 'deadpork-dead-pork-dead-porkdeadpork'
def test_buildlogs_invalid_build_uuid(self):
self.login('public', 'password')
self.getResponse('web.buildlogs', build_uuid='bad_build_uuid', expected_code=400)
def test_buildlogs_not_logged_in(self):
self.getResponse('web.buildlogs', build_uuid=self.build_uuid, expected_code=403)
def test_buildlogs_unauthorized(self):
self.login('reader', 'password')
self.getResponse('web.buildlogs', build_uuid=self.build_uuid, expected_code=403)
def test_buildlogs_logsarchived(self):
self.login('public', 'password')
with patch('data.model.build.RepositoryBuild', logs_archived=True):
self.getResponse('web.buildlogs', build_uuid=self.build_uuid, expected_code=403)
def test_buildlogs_successful(self):
self.login('public', 'password')
logs = ['log1', 'log2']
with patch('endpoints.web.build_logs.get_log_entries', return_value=(None, logs) ):
resp = self.getResponse('web.buildlogs', build_uuid=self.build_uuid, expected_code=200)
self.assertEquals({"logs": logs}, py_json.loads(resp))
class ArchivedLogsTestCase(EndpointTestCase):
build_uuid = 'deadpork-dead-pork-dead-porkdeadpork'
def test_logarchive_invalid_build_uuid(self):
self.login('public', 'password')
self.getResponse('web.logarchive', file_id='bad_build_uuid', expected_code=403)
def test_logarchive_not_logged_in(self):
self.getResponse('web.logarchive', file_id=self.build_uuid, expected_code=403)
def test_logarchive_unauthorized(self):
self.login('reader', 'password')
self.getResponse('web.logarchive', file_id=self.build_uuid, expected_code=403)
def test_logarchive_file_not_found(self):
self.login('public', 'password')
self.getResponse('web.logarchive', file_id=self.build_uuid, expected_code=403)
def test_logarchive_successful(self):
self.login('public', 'password')
data = b"my_file_stream"
mock_file = BytesIO(zlib.compressobj(-1, zlib.DEFLATED, WINDOW_BUFFER_SIZE).compress(data))
with patch('endpoints.web.log_archive._storage.stream_read_file', return_value=mock_file):
self.getResponse('web.logarchive', file_id=self.build_uuid, expected_code=200)
class WebhookEndpointTestCase(EndpointTestCase):
def test_invalid_build_trigger_webhook(self):
self.postResponse('webhooks.build_trigger_webhook', trigger_uuid='invalidtrigger',
expected_code=404)
def test_valid_build_trigger_webhook_invalid_auth(self):
trigger = list(model.build.list_build_triggers('devtable', 'building'))[0]
self.postResponse('webhooks.build_trigger_webhook', trigger_uuid=trigger.uuid,
expected_code=403)
def test_valid_build_trigger_webhook_cookie_auth(self):
self.login('devtable', 'password')
# Cookie auth is not supported, so this should 403
trigger = list(model.build.list_build_triggers('devtable', 'building'))[0]
self.postResponse('webhooks.build_trigger_webhook', trigger_uuid=trigger.uuid,
expected_code=403)
def test_valid_build_trigger_webhook_missing_payload(self):
auth_header = 'Basic %s' % (base64.b64encode('devtable:password'))
trigger = list(model.build.list_build_triggers('devtable', 'building'))[0]
self.postResponse('webhooks.build_trigger_webhook', trigger_uuid=trigger.uuid,
expected_code=400, headers={'Authorization': auth_header})
def test_valid_build_trigger_webhook_invalid_payload(self):
auth_header = 'Basic %s' % (base64.b64encode('devtable:password'))
trigger = list(model.build.list_build_triggers('devtable', 'building'))[0]
self.postResponse('webhooks.build_trigger_webhook', trigger_uuid=trigger.uuid,
expected_code=400,
headers={'Authorization': auth_header, 'Content-Type': 'application/json'},
data={'invalid': 'payload'})
class WebEndpointTestCase(EndpointTestCase):
def test_index(self):
self.getResponse('web.index')
def test_robots(self):
self.getResponse('web.robots')
def test_repo_view(self):
self.getResponse('web.repository', path='devtable/simple')
def test_unicode_repo_view(self):
self.getResponse('web.repository', path='%E2%80%8Bcoreos/hyperkube%E2%80%8B')
def test_org_view(self):
self.getResponse('web.org_view', path='buynlarge')
def test_user_view(self):
self.getResponse('web.user_view', path='devtable')
def test_confirm_repo_email(self):
code = model.repository.create_email_authorization_for_repo('devtable', 'simple', 'foo@bar.com')
self.getResponse('web.confirm_repo_email', code=code.code)
found = model.repository.get_email_authorized_for_repo('devtable', 'simple', 'foo@bar.com')
self.assertTrue(found.confirmed)
def test_confirm_email(self):
user = model.user.get_user('devtable')
self.assertNotEquals(user.email, 'foo@bar.com')
confirmation_code = model.user.create_confirm_email_code(user, 'foo@bar.com')
self.getResponse('web.confirm_email', code=confirmation_code, expected_code=302)
user = model.user.get_user('devtable')
self.assertEquals(user.email, 'foo@bar.com')
def test_confirm_recovery(self):
# Try for an invalid code.
self.getResponse('web.confirm_recovery', code='someinvalidcode', expected_code=200)
# Create a valid code and try.
user = model.user.get_user('devtable')
confirmation_code = model.user.create_reset_password_email_code(user.email)
self.getResponse('web.confirm_recovery', code=confirmation_code, expected_code=302)
def test_confirm_recovery_verified(self):
# Create a valid code and try.
user = model.user.get_user('devtable')
user.verified = False
user.save()
confirmation_code = model.user.create_reset_password_email_code(user.email)
self.getResponse('web.confirm_recovery', code=confirmation_code, expected_code=302)
# Ensure the current user is the expected user and that they are verified.
user = model.user.get_user('devtable')
self.assertTrue(user.verified)
self.getResponse('web.receipt', expected_code=404) # Will 401 if no user.
def test_request_authorization_code(self):
# Try for an invalid client.
self.getResponse('web.request_authorization_code', client_id='foo', redirect_uri='bar',
scope='baz', expected_code=404)
# Try for a valid client.
org = model.organization.get_organization('buynlarge')
assert org
app = model.oauth.create_application(org, 'test', 'http://foo/bar', 'http://foo/bar/baz')
self.getResponse('web.request_authorization_code',
client_id=app.client_id,
redirect_uri=app.redirect_uri,
scope='repo:read',
expected_code=200)
def test_build_status_badge(self):
# Try for an invalid repository.
self.getResponse('web.build_status_badge', repository='foo/bar', expected_code=404)
# Try for a public repository.
self.getResponse('web.build_status_badge', repository='public/publicrepo')
# Try for an private repository.
self.getResponse('web.build_status_badge', repository='devtable/simple',
expected_code=404)
# Try for an private repository with an invalid token.
self.getResponse('web.build_status_badge', repository='devtable/simple',
token='sometoken', expected_code=404)
# Try for an private repository with a valid token.
repository = model.repository.get_repository('devtable', 'simple')
self.getResponse('web.build_status_badge', repository='devtable/simple',
token=repository.badge_token)
def test_attach_custom_build_trigger(self):
self.getResponse('web.attach_custom_build_trigger', repository='foo/bar', expected_code=401)
self.getResponse('web.attach_custom_build_trigger', repository='devtable/simple', expected_code=401)
self.login('freshuser', 'password')
self.getResponse('web.attach_custom_build_trigger', repository='devtable/simple', expected_code=403)
self.login('devtable', 'password')
self.getResponse('web.attach_custom_build_trigger', repository='devtable/simple', expected_code=302)
def test_redirect_to_repository(self):
self.getResponse('web.redirect_to_repository', repository='foo/bar', expected_code=404)
self.getResponse('web.redirect_to_repository', repository='public/publicrepo', expected_code=302)
self.getResponse('web.redirect_to_repository', repository='devtable/simple', expected_code=403)
self.login('devtable', 'password')
self.getResponse('web.redirect_to_repository', repository='devtable/simple', expected_code=302)
def test_redirect_to_namespace(self):
self.getResponse('web.redirect_to_namespace', namespace='unknown', expected_code=404)
self.getResponse('web.redirect_to_namespace', namespace='devtable', expected_code=302)
self.getResponse('web.redirect_to_namespace', namespace='buynlarge', expected_code=302)
class OAuthTestCase(EndpointTestCase):
def test_authorize_nologin(self):
form = {
'client_id': 'someclient',
'redirect_uri': 'http://localhost:5000/foobar',
'scope': 'user:admin',
}
self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=401)
def test_authorize_invalidclient(self):
self.login('devtable', 'password')
form = {
'client_id': 'someclient',
'redirect_uri': 'http://localhost:5000/foobar',
'scope': 'user:admin',
}
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
self.assertEquals('http://localhost:5000/foobar?error=unauthorized_client', resp.headers['Location'])
def test_authorize_invalidscope(self):
self.login('devtable', 'password')
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'invalid:scope',
}
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
self.assertEquals('http://localhost:8000/o2c.html?error=invalid_scope', resp.headers['Location'])
def test_authorize_invalidredirecturi(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://some/invalid/uri',
'scope': 'user:admin',
}
self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=400)
def test_authorize_success(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
self.assertTrue('access_token=' in resp.headers['Location'])
def test_authorize_nocsrf(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
self.postResponse('web.authorize_application', form=form, with_csrf=False, expected_code=403)
def test_authorize_nocsrf_withinvalidheader(self):
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
headers = dict(authorization='Some random header')
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
def test_authorize_nocsrf_withbadheader(self):
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
headers = dict(authorization='Basic ' + base64.b64encode('devtable:invalidpassword'))
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False,
expected_code=401)
def test_authorize_nocsrf_correctheader(self):
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
# Try without the client id being in the whitelist.
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False,
expected_code=403)
# Add the client ID to the whitelist and try again.
app.config['DIRECT_OAUTH_CLIENTID_WHITELIST'] = ['deadbeef']
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
resp = self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=302)
self.assertTrue('access_token=' in resp.headers['Location'])
def test_authorize_nocsrf_ratelimiting(self):
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
# Try without the client id being in the whitelist a few times, making sure we eventually get rate limited.
headers = dict(authorization='Basic ' + base64.b64encode('devtable:invalidpassword'))
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
counter = 0
while True:
r = self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=None)
self.assertNotEquals(200, r.status_code)
counter = counter + 1
if counter > 5:
self.fail('Exponential backoff did not fire')
if r.status_code == 429:
break
class KeyServerTestCase(EndpointTestCase):
def _get_test_jwt_payload(self):
return {
'iss': 'sample_service',
'aud': keyserver.JWT_AUDIENCE,
'exp': int(time.time()) + 60,
'iat': int(time.time()),
'nbf': int(time.time()),
}
def test_list_service_keys(self):
# Retrieve all the keys.
all_keys = model.service_keys.list_all_keys()
visible_jwks = [jwk_with_kid(key) for key in model.service_keys.list_service_keys('sample_service')]
invisible_jwks = []
for key in all_keys:
is_expired = key.expiration_date and key.expiration_date <= datetime.utcnow()
if key.service != 'sample_service' or key.approval is None or is_expired:
invisible_jwks.append(key.jwk)
rv = self.getResponse('key_server.list_service_keys', service='sample_service')
jwkset = py_json.loads(rv)
# Make sure the hidden keys are not returned and the visible ones are returned.
self.assertTrue(len(visible_jwks) > 0)
self.assertTrue(len(invisible_jwks) > 0)
self.assertEquals(len(visible_jwks), len(jwkset['keys']))
for jwk in jwkset['keys']:
self.assertIn(jwk, visible_jwks)
self.assertNotIn(jwk, invisible_jwks)
def test_get_service_key(self):
# 200 for an approved key
self.getResponse('key_server.get_service_key', service='sample_service', kid='kid1')
# 409 for an unapproved key
self.getResponse('key_server.get_service_key', service='sample_service', kid='kid3',
expected_code=409)
# 404 for a non-existant key
self.getResponse('key_server.get_service_key', service='sample_service', kid='kid9999',
expected_code=404)
# 403 for an approved but expired key that is inside of the 2 week window.
self.getResponse('key_server.get_service_key', service='sample_service', kid='kid6',
expected_code=403)
# 404 for an approved, expired key that is outside of the 2 week window.
self.getResponse('key_server.get_service_key', service='sample_service', kid='kid7',
expected_code=404)
def test_put_service_key(self):
# No Authorization header should yield a 400
self.putResponse('key_server.put_service_key', service='sample_service', kid='kid420',
expected_code=400)
# Mint a JWT with our test payload
private_key = RSA.generate(2048)
jwk = RSAKey(key=private_key.publickey()).serialize()
payload = self._get_test_jwt_payload()
token = jwt.encode(payload, private_key.exportKey('PEM'), 'RS256')
# Invalid service name should yield a 400.
self.putResponse('key_server.put_service_key', service='sample service', kid='kid420',
headers={
'Authorization': 'Bearer %s' % token,
'Content-Type': 'application/json',
}, data=jwk, expected_code=400)
# Publish a new key
with assert_action_logged('service_key_create'):
self.putResponse('key_server.put_service_key', service='sample_service', kid='kid420',
headers={
'Authorization': 'Bearer %s' % token,
'Content-Type': 'application/json',
}, data=jwk, expected_code=202)
# Ensure that the key exists but is unapproved.
self.getResponse('key_server.get_service_key', service='sample_service', kid='kid420',
expected_code=409)
# Attempt to rotate the key. Since not approved, it will fail.
token = jwt.encode(payload, private_key.exportKey('PEM'), 'RS256', headers={'kid': 'kid420'})
self.putResponse('key_server.put_service_key', service='sample_service', kid='kid6969',
headers={
'Authorization': 'Bearer %s' % token,
'Content-Type': 'application/json',
}, data=jwk, expected_code=403)
# Approve the key.
model.service_keys.approve_service_key('kid420', ServiceKeyApprovalType.SUPERUSER, approver=1)
# Rotate that new key
with assert_action_logged('service_key_rotate'):
token = jwt.encode(payload, private_key.exportKey('PEM'), 'RS256', headers={'kid': 'kid420'})
self.putResponse('key_server.put_service_key', service='sample_service', kid='kid6969',
headers={
'Authorization': 'Bearer %s' % token,
'Content-Type': 'application/json',
}, data=jwk, expected_code=200)
# Rotation should only work when signed by the previous key
private_key = RSA.generate(2048)
jwk = RSAKey(key=private_key.publickey()).serialize()
token = jwt.encode(payload, private_key.exportKey('PEM'), 'RS256', headers={'kid': 'kid420'})
self.putResponse('key_server.put_service_key', service='sample_service', kid='kid6969',
headers={
'Authorization': 'Bearer %s' % token,
'Content-Type': 'application/json',
}, data=jwk, expected_code=403)
def test_attempt_delete_service_key_with_no_kid_signer(self):
# Generate two keys, approving the first.
private_key, _ = model.service_keys.generate_service_key('sample_service', None, kid='first')
# Mint a JWT with our test payload but *no kid*.
token = jwt.encode(self._get_test_jwt_payload(), private_key.exportKey('PEM'), 'RS256',
headers={})
# Using the credentials of our key, attempt to delete our unapproved key
self.deleteResponse('key_server.delete_service_key',
headers={'Authorization': 'Bearer %s' % token},
expected_code=400, service='sample_service', kid='first')
def test_attempt_delete_service_key_with_expired_key(self):
# Generate two keys, approving the first.
private_key, _ = model.service_keys.generate_service_key('sample_service', None, kid='first')
model.service_keys.approve_service_key('first', ServiceKeyApprovalType.SUPERUSER, approver=1)
model.service_keys.generate_service_key('sample_service', None, kid='second')
# Mint a JWT with our test payload
token = jwt.encode(self._get_test_jwt_payload(), private_key.exportKey('PEM'), 'RS256',
headers={'kid': 'first'})
# Set the expiration of the first to now - some time.
model.service_keys.set_key_expiration('first', datetime.utcnow() - timedelta(seconds=100))
# Using the credentials of our second key, attempt to delete our unapproved key
self.deleteResponse('key_server.delete_service_key',
headers={'Authorization': 'Bearer %s' % token},
expected_code=403, service='sample_service', kid='second')
# Set the expiration to the future and delete the key.
model.service_keys.set_key_expiration('first', datetime.utcnow() + timedelta(seconds=100))
with assert_action_logged('service_key_delete'):
self.deleteEmptyResponse('key_server.delete_service_key',
headers={'Authorization': 'Bearer %s' % token},
expected_code=204, service='sample_service', kid='second')
def test_delete_unapproved_service_key(self):
# No Authorization header should yield a 400
self.deleteResponse('key_server.delete_service_key', expected_code=400,
service='sample_service', kid='kid1')
# Generate an unapproved key.
private_key, _ = model.service_keys.generate_service_key('sample_service', None,
kid='unapprovedkeyhere')
# Mint a JWT with our test payload
token = jwt.encode(self._get_test_jwt_payload(), private_key.exportKey('PEM'), 'RS256',
headers={'kid': 'unapprovedkeyhere'})
# Delete our unapproved key with itself.
with assert_action_logged('service_key_delete'):
self.deleteEmptyResponse('key_server.delete_service_key',
headers={'Authorization': 'Bearer %s' % token},
expected_code=204, service='sample_service', kid='unapprovedkeyhere')
def test_delete_chained_service_key(self):
# No Authorization header should yield a 400
self.deleteResponse('key_server.delete_service_key', expected_code=400,
service='sample_service', kid='kid1')
# Generate two keys.
private_key, _ = model.service_keys.generate_service_key('sample_service', None, kid='kid123')
model.service_keys.generate_service_key('sample_service', None, kid='kid321')
# Mint a JWT with our test payload
token = jwt.encode(self._get_test_jwt_payload(), private_key.exportKey('PEM'), 'RS256',
headers={'kid': 'kid123'})
# Using the credentials of our second key, attempt tp delete our unapproved key
self.deleteResponse('key_server.delete_service_key',
headers={'Authorization': 'Bearer %s' % token},
expected_code=403, service='sample_service', kid='kid321')
# Approve the second key.
model.service_keys.approve_service_key('kid123', ServiceKeyApprovalType.SUPERUSER, approver=1)
# Using the credentials of our approved key, delete our unapproved key
with assert_action_logged('service_key_delete'):
self.deleteEmptyResponse('key_server.delete_service_key',
headers={'Authorization': 'Bearer %s' % token},
expected_code=204, service='sample_service', kid='kid321')
# Attempt to delete a key signed by a key from a different service
bad_token = jwt.encode(self._get_test_jwt_payload(), private_key.exportKey('PEM'), 'RS256',
headers={'kid': 'kid5'})
self.deleteResponse('key_server.delete_service_key',
headers={'Authorization': 'Bearer %s' % bad_token},
expected_code=403, service='sample_service', kid='kid123')
# Delete a self-signed, approved key
with assert_action_logged('service_key_delete'):
self.deleteEmptyResponse('key_server.delete_service_key',
headers={'Authorization': 'Bearer %s' % token},
expected_code=204, service='sample_service', kid='kid123')
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1,327 @@
import base64
import unittest
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile
from contextlib import contextmanager
import jwt
import requests
from Crypto.PublicKey import RSA
from flask import Flask, jsonify, request, make_response
from app import app
from data.users import ExternalJWTAuthN
from initdb import setup_database_for_testing, finished_database_for_testing
from test.helpers import liveserver_app
_PORT_NUMBER = 5001
@contextmanager
def fake_jwt(requires_email=True):
""" Context manager which instantiates and runs a webserver with a fake JWT implementation,
until the result is yielded.
Usage:
with fake_jwt() as jwt_auth:
# Make jwt_auth requests.
"""
jwt_app, port, public_key = _create_app(requires_email)
server_url = 'http://' + jwt_app.config['SERVER_HOSTNAME']
verify_url = server_url + '/user/verify'
query_url = server_url + '/user/query'
getuser_url = server_url + '/user/get'
jwt_auth = ExternalJWTAuthN(verify_url, query_url, getuser_url, 'authy', '',
app.config['HTTPCLIENT'], 300,
public_key_path=public_key.name,
requires_email=requires_email)
with liveserver_app(jwt_app, port):
yield jwt_auth
def _generate_certs():
public_key = NamedTemporaryFile(delete=True)
key = RSA.generate(1024)
private_key_data = key.exportKey('PEM')
pubkey = key.publickey()
public_key.write(pubkey.exportKey('OpenSSH'))
public_key.seek(0)
return (public_key, private_key_data)
def _create_app(emails=True):
global _PORT_NUMBER
_PORT_NUMBER = _PORT_NUMBER + 1
public_key, private_key_data = _generate_certs()
users = [
{'name': 'cool.user', 'email': 'user@domain.com', 'password': 'password'},
{'name': 'some.neat.user', 'email': 'neat@domain.com', 'password': 'foobar'},
# Feature Flag: Email Blacklisting
{'name': 'blacklistedcom', 'email': 'foo@blacklisted.com', 'password': 'somepass'},
{'name': 'blacklistednet', 'email': 'foo@blacklisted.net', 'password': 'somepass'},
{'name': 'blacklistedorg', 'email': 'foo@blacklisted.org', 'password': 'somepass'},
{'name': 'notblacklistedcom', 'email': 'foo@notblacklisted.com', 'password': 'somepass'},
]
jwt_app = Flask('testjwt')
jwt_app.config['SERVER_HOSTNAME'] = 'localhost:%s' % _PORT_NUMBER
def _get_basic_auth():
data = base64.b64decode(request.headers['Authorization'][len('Basic '):])
return data.split(':', 1)
@jwt_app.route('/user/query', methods=['GET'])
def query_users():
query = request.args.get('query')
results = []
for user in users:
if user['name'].startswith(query):
result = {
'username': user['name'],
}
if emails:
result['email'] = user['email']
results.append(result)
token_data = {
'iss': 'authy',
'aud': 'quay.io/jwtauthn/query',
'nbf': datetime.utcnow(),
'iat': datetime.utcnow(),
'exp': datetime.utcnow() + timedelta(seconds=60),
'results': results,
}
encoded = jwt.encode(token_data, private_key_data, 'RS256')
return jsonify({
'token': encoded
})
@jwt_app.route('/user/get', methods=['GET'])
def get_user():
username = request.args.get('username')
if username == 'disabled':
return make_response('User is currently disabled', 401)
for user in users:
if user['name'] == username or user['email'] == username:
token_data = {
'iss': 'authy',
'aud': 'quay.io/jwtauthn/getuser',
'nbf': datetime.utcnow(),
'iat': datetime.utcnow(),
'exp': datetime.utcnow() + timedelta(seconds=60),
'sub': user['name'],
'email': user['email'],
}
encoded = jwt.encode(token_data, private_key_data, 'RS256')
return jsonify({
'token': encoded
})
return make_response('Invalid username or password', 404)
@jwt_app.route('/user/verify', methods=['GET'])
def verify_user():
username, password = _get_basic_auth()
if username == 'disabled':
return make_response('User is currently disabled', 401)
for user in users:
if user['name'] == username or user['email'] == username:
if password != user['password']:
return make_response('', 404)
token_data = {
'iss': 'authy',
'aud': 'quay.io/jwtauthn',
'nbf': datetime.utcnow(),
'iat': datetime.utcnow(),
'exp': datetime.utcnow() + timedelta(seconds=60),
'sub': user['name'],
'email': user['email'],
}
encoded = jwt.encode(token_data, private_key_data, 'RS256')
return jsonify({
'token': encoded
})
return make_response('Invalid username or password', 404)
jwt_app.config['TESTING'] = True
return jwt_app, _PORT_NUMBER, public_key
class JWTAuthTestMixin:
""" Mixin defining all the JWT auth tests. """
maxDiff = None
@property
def emails(self):
raise NotImplementedError
def setUp(self):
setup_database_for_testing(self)
self.app = app.test_client()
self.ctx = app.test_request_context()
self.ctx.__enter__()
self.session = requests.Session()
def tearDown(self):
finished_database_for_testing(self)
self.ctx.__exit__(True, None, None)
def test_verify_and_link_user(self):
with fake_jwt(self.emails) as jwt_auth:
result, error_message = jwt_auth.verify_and_link_user('invaliduser', 'foobar')
self.assertEquals('Invalid username or password', error_message)
self.assertIsNone(result)
result, _ = jwt_auth.verify_and_link_user('cool.user', 'invalidpassword')
self.assertIsNone(result)
result, _ = jwt_auth.verify_and_link_user('cool.user', 'password')
self.assertIsNotNone(result)
self.assertEquals('cool_user', result.username)
result, _ = jwt_auth.verify_and_link_user('some.neat.user', 'foobar')
self.assertIsNotNone(result)
self.assertEquals('some_neat_user', result.username)
def test_confirm_existing_user(self):
with fake_jwt(self.emails) as jwt_auth:
# Create the users in the DB.
result, _ = jwt_auth.verify_and_link_user('cool.user', 'password')
self.assertIsNotNone(result)
result, _ = jwt_auth.verify_and_link_user('some.neat.user', 'foobar')
self.assertIsNotNone(result)
# Confirm a user with the same internal and external username.
result, _ = jwt_auth.confirm_existing_user('cool_user', 'invalidpassword')
self.assertIsNone(result)
result, _ = jwt_auth.confirm_existing_user('cool_user', 'password')
self.assertIsNotNone(result)
self.assertEquals('cool_user', result.username)
# Fail to confirm the *external* username, which should return nothing.
result, _ = jwt_auth.confirm_existing_user('some.neat.user', 'password')
self.assertIsNone(result)
# Now confirm the internal username.
result, _ = jwt_auth.confirm_existing_user('some_neat_user', 'foobar')
self.assertIsNotNone(result)
self.assertEquals('some_neat_user', result.username)
def test_disabled_user_custom_error(self):
with fake_jwt(self.emails) as jwt_auth:
result, error_message = jwt_auth.verify_and_link_user('disabled', 'password')
self.assertIsNone(result)
self.assertEquals('User is currently disabled', error_message)
def test_query(self):
with fake_jwt(self.emails) as jwt_auth:
# Lookup `cool`.
results, identifier, error_message = jwt_auth.query_users('cool')
self.assertIsNone(error_message)
self.assertEquals('jwtauthn', identifier)
self.assertEquals(1, len(results))
self.assertEquals('cool.user', results[0].username)
self.assertEquals('user@domain.com' if self.emails else None, results[0].email)
# Lookup `some`.
results, identifier, error_message = jwt_auth.query_users('some')
self.assertIsNone(error_message)
self.assertEquals('jwtauthn', identifier)
self.assertEquals(1, len(results))
self.assertEquals('some.neat.user', results[0].username)
self.assertEquals('neat@domain.com' if self.emails else None, results[0].email)
# Lookup `unknown`.
results, identifier, error_message = jwt_auth.query_users('unknown')
self.assertIsNone(error_message)
self.assertEquals('jwtauthn', identifier)
self.assertEquals(0, len(results))
def test_get_user(self):
with fake_jwt(self.emails) as jwt_auth:
# Lookup cool.user.
result, error_message = jwt_auth.get_user('cool.user')
self.assertIsNone(error_message)
self.assertIsNotNone(result)
self.assertEquals('cool.user', result.username)
self.assertEquals('user@domain.com', result.email)
# Lookup some.neat.user.
result, error_message = jwt_auth.get_user('some.neat.user')
self.assertIsNone(error_message)
self.assertIsNotNone(result)
self.assertEquals('some.neat.user', result.username)
self.assertEquals('neat@domain.com', result.email)
# Lookup unknown user.
result, error_message = jwt_auth.get_user('unknownuser')
self.assertIsNone(result)
def test_link_user(self):
with fake_jwt(self.emails) as jwt_auth:
# Link cool.user.
user, error_message = jwt_auth.link_user('cool.user')
self.assertIsNone(error_message)
self.assertIsNotNone(user)
self.assertEquals('cool_user', user.username)
# Link again. Should return the same user record.
user_again, _ = jwt_auth.link_user('cool.user')
self.assertEquals(user_again.id, user.id)
# Confirm cool.user.
result, _ = jwt_auth.confirm_existing_user('cool_user', 'password')
self.assertIsNotNone(result)
self.assertEquals('cool_user', result.username)
def test_link_invalid_user(self):
with fake_jwt(self.emails) as jwt_auth:
user, error_message = jwt_auth.link_user('invaliduser')
self.assertIsNotNone(error_message)
self.assertIsNone(user)
class JWTAuthNoEmailTestCase(JWTAuthTestMixin, unittest.TestCase):
""" Test cases for JWT auth, with emails disabled. """
@property
def emails(self):
return False
class JWTAuthTestCase(JWTAuthTestMixin, unittest.TestCase):
""" Test cases for JWT auth, with emails enabled. """
@property
def emails(self):
return True
if __name__ == '__main__':
unittest.main()

17
test/test_gunicorn_running.sh Executable file
View file

@ -0,0 +1,17 @@
set -e
echo "Registry"
curl --fail http://localhost:8080/v1/_internal_ping
echo ""
echo "Verbs"
curl --fail http://localhost:8080/c1/_internal_ping
echo ""
echo "Security scan"
curl --fail http://localhost:8080/secscan/_internal_ping
echo ""
echo "Web"
curl --fail http://localhost:8080/_internal_ping
echo ""

418
test/test_keystone_auth.py Normal file
View file

@ -0,0 +1,418 @@
import json
import os
import unittest
import requests
from flask import Flask, request, abort, make_response
from contextlib import contextmanager
from test.helpers import liveserver_app
from data.users.keystone import get_keystone_users
from initdb import setup_database_for_testing, finished_database_for_testing
_PORT_NUMBER = 5001
@contextmanager
def fake_keystone(version=3, requires_email=True):
""" Context manager which instantiates and runs a webserver with a fake Keystone implementation,
until the result is yielded.
Usage:
with fake_keystone(version) as keystone_auth:
# Make keystone_auth requests.
"""
keystone_app, port = _create_app(requires_email)
server_url = 'http://' + keystone_app.config['SERVER_HOSTNAME']
endpoint_url = server_url + '/v3'
if version == 2:
endpoint_url = server_url + '/v2.0/auth'
keystone_auth = get_keystone_users(version, endpoint_url,
'adminuser', 'adminpass', 'admintenant',
requires_email=requires_email)
with liveserver_app(keystone_app, port):
yield keystone_auth
def _create_app(requires_email=True):
global _PORT_NUMBER
_PORT_NUMBER = _PORT_NUMBER + 1
server_url = 'http://localhost:%s' % (_PORT_NUMBER)
users = [
{'username': 'adminuser', 'name': 'Admin User', 'password': 'adminpass'},
{'username': 'cool.user', 'name': 'Cool User', 'password': 'password'},
{'username': 'some.neat.user', 'name': 'Neat User', 'password': 'foobar'},
]
# Feature Flag: Email-based Blacklisting
# Create additional, mocked Users
test_domains = ('blacklisted.com', 'blacklisted.net', 'blacklisted.org',
'notblacklisted.com', 'mail.blacklisted.com')
for domain in test_domains:
mock_email = 'foo@' + domain # e.g. foo@blacklisted.com
new_user = {
'username': mock_email, # Simplifies consistent querying in tests
'name': domain.replace('.', ''), # blacklisted.com => blacklistedcom
'email': mock_email,
'password': 'somepass'
}
users.append(new_user)
groups = [
{'id': 'somegroupid', 'name': 'somegroup', 'description': 'Hi there!',
'members': ['adminuser', 'cool.user']},
{'id': 'admintenant', 'name': 'somegroup', 'description': 'Hi there!',
'members': ['adminuser', 'cool.user']},
]
def _get_user(username):
for user in users:
if user['username'] == username:
user_data = {}
user_data['id'] = username
user_data['name'] = username
if requires_email:
user_data['email'] = user.get('email') or username + '@example.com'
return user_data
return None
ks_app = Flask('testks')
ks_app.config['SERVER_HOSTNAME'] = 'localhost:%s' % _PORT_NUMBER
if os.environ.get('DEBUG') == 'true':
ks_app.config['DEBUG'] = True
@ks_app.route('/v2.0/admin/users/<userid>', methods=['GET'])
def getuser(userid):
for user in users:
if user['username'] == userid:
user_data = {}
if requires_email:
user_data['email'] = user.get('email') or userid + '@example.com'
return json.dumps({
'user': user_data
})
abort(404)
# v2 referred to all groups as tenants, so replace occurrences of 'group' with 'tenant'
@ks_app.route('/v2.0/admin/tenants/<tenant>/users', methods=['GET'])
def getv2_tenant_members(tenant):
return getv3groupmembers(tenant)
@ks_app.route('/v3/identity/groups/<groupid>/users', methods=['GET'])
def getv3groupmembers(groupid):
for group in groups:
if group['id'] == groupid:
group_data = {
"links": {},
"users": [_get_user(username) for username in group['members']],
}
return json.dumps(group_data)
abort(404)
@ks_app.route('/v3/identity/groups/<groupid>', methods=['GET'])
def getv3group(groupid):
for group in groups:
if group['id'] == groupid:
group_data = {
"description": group['description'],
"domain_id": "default",
"id": groupid,
"links": {},
"name": group['name'],
}
return json.dumps({'group': group_data})
abort(404)
@ks_app.route('/v3/identity/users/<userid>', methods=['GET'])
def getv3user(userid):
for user in users:
if user['username'] == userid:
user_data = {
"domain_id": "default",
"enabled": True,
"id": user['username'],
"links": {},
"name": user['username'],
}
if requires_email:
user_data['email'] = user.get('email') or user['username'] + '@example.com'
return json.dumps({
'user': user_data
})
abort(404)
@ks_app.route('/v3/identity/users', methods=['GET'])
def v3identity():
returned = []
for user in users:
if not request.args.get('name') or user['username'].startswith(request.args.get('name')):
returned.append({
"domain_id": "default",
"enabled": True,
"id": user['username'],
"links": {},
"name": user['username'],
"email": user.get('email') or user['username'] + '@example.com',
})
return json.dumps({"users": returned})
@ks_app.route('/v3/auth/tokens', methods=['POST'])
def v3tokens():
creds = request.json['auth']['identity']['password']['user']
for user in users:
if creds['name'] == user['username'] and creds['password'] == user['password']:
data = json.dumps({
"token": {
"methods": [
"password"
],
"roles": [
{
"id": "9fe2ff9ee4384b1894a90878d3e92bab",
"name": "_member_"
},
{
"id": "c703057be878458588961ce9a0ce686b",
"name": "admin"
}
],
"project": {
"domain": {
"id": "default",
"name": "Default"
},
"id": "8538a3f13f9541b28c2620eb19065e45",
"name": "admin"
},
"catalog": [
{
"endpoints": [
{
"url": server_url + '/v3/identity',
"region": "RegionOne",
"interface": "admin",
"id": "29beb2f1567642eb810b042b6719ea88"
},
],
"type": "identity",
"id": "bd73972c0e14fb69bae8ff76e112a90",
"name": "keystone"
}
],
"extras": {
},
"user": {
"domain": {
"id": "default",
"name": "Default"
},
"id": user['username'],
"name": "admin"
},
"audit_ids": [
"yRt0UrxJSs6-WYJgwEMMmg"
],
"issued_at": "2014-06-16T22:24:26.089380",
"expires_at": "2020-06-16T23:24:26Z",
}
})
response = make_response(data, 200)
response.headers['X-Subject-Token'] = 'sometoken'
return response
abort(403)
@ks_app.route('/v2.0/auth/tokens', methods=['POST'])
def tokens():
creds = request.json['auth'][u'passwordCredentials']
for user in users:
if creds['username'] == user['username'] and creds['password'] == user['password']:
return json.dumps({
"access": {
"token": {
"issued_at": "2014-06-16T22:24:26.089380",
"expires": "2020-06-16T23:24:26Z",
"id": creds['username'],
"tenant": {"id": "sometenant"},
},
"serviceCatalog":[
{
"endpoints": [
{
"adminURL": server_url + '/v2.0/admin',
}
],
"endpoints_links": [],
"type": "identity",
"name": "admin",
},
],
"user": {
"username": creds['username'],
"roles_links": [],
"id": creds['username'],
"roles": [],
"name": user['name'],
},
"metadata": {
"is_admin": 0,
"roles": [],
},
},
})
abort(403)
return ks_app, _PORT_NUMBER
class KeystoneAuthTestsMixin:
maxDiff = None
@property
def emails(self):
raise NotImplementedError
def fake_keystone(self):
raise NotImplementedError
def setUp(self):
setup_database_for_testing(self)
self.session = requests.Session()
def tearDown(self):
finished_database_for_testing(self)
def test_invalid_user(self):
with self.fake_keystone() as keystone:
(user, _) = keystone.verify_credentials('unknownuser', 'password')
self.assertIsNone(user)
def test_invalid_password(self):
with self.fake_keystone() as keystone:
(user, _) = keystone.verify_credentials('cool.user', 'notpassword')
self.assertIsNone(user)
def test_cooluser(self):
with self.fake_keystone() as keystone:
(user, _) = keystone.verify_credentials('cool.user', 'password')
self.assertEquals(user.username, 'cool.user')
self.assertEquals(user.email, 'cool.user@example.com' if self.emails else None)
def test_neatuser(self):
with self.fake_keystone() as keystone:
(user, _) = keystone.verify_credentials('some.neat.user', 'foobar')
self.assertEquals(user.username, 'some.neat.user')
self.assertEquals(user.email, 'some.neat.user@example.com' if self.emails else None)
class KeystoneV2AuthNoEmailTests(KeystoneAuthTestsMixin, unittest.TestCase):
def fake_keystone(self):
return fake_keystone(2, requires_email=False)
@property
def emails(self):
return False
class KeystoneV3AuthNoEmailTests(KeystoneAuthTestsMixin, unittest.TestCase):
def fake_keystone(self):
return fake_keystone(3, requires_email=False)
@property
def emails(self):
return False
class KeystoneV2AuthTests(KeystoneAuthTestsMixin, unittest.TestCase):
def fake_keystone(self):
return fake_keystone(2, requires_email=True)
@property
def emails(self):
return True
class KeystoneV3AuthTests(KeystoneAuthTestsMixin, unittest.TestCase):
def fake_keystone(self):
return fake_keystone(3, requires_email=True)
def emails(self):
return True
def test_query(self):
with self.fake_keystone() as keystone:
# Lookup cool.
(response, federated_id, error_message) = keystone.query_users('cool')
self.assertIsNone(error_message)
self.assertEquals(1, len(response))
self.assertEquals('keystone', federated_id)
user_info = response[0]
self.assertEquals("cool.user", user_info.username)
# Lookup unknown.
(response, federated_id, error_message) = keystone.query_users('unknown')
self.assertIsNone(error_message)
self.assertEquals(0, len(response))
self.assertEquals('keystone', federated_id)
def test_link_user(self):
with self.fake_keystone() as keystone:
# Link someuser.
user, error_message = keystone.link_user('cool.user')
self.assertIsNone(error_message)
self.assertIsNotNone(user)
self.assertEquals('cool_user', user.username)
self.assertEquals('cool.user@example.com', user.email)
# Link again. Should return the same user record.
user_again, _ = keystone.link_user('cool.user')
self.assertEquals(user_again.id, user.id)
# Confirm someuser.
result, _ = keystone.confirm_existing_user('cool_user', 'password')
self.assertIsNotNone(result)
self.assertEquals('cool_user', result.username)
def test_check_group_lookup_args(self):
with self.fake_keystone() as keystone:
(status, err) = keystone.check_group_lookup_args({})
self.assertFalse(status)
self.assertEquals('Missing group_id', err)
(status, err) = keystone.check_group_lookup_args({'group_id': 'unknownid'})
self.assertFalse(status)
self.assertEquals('Group not found', err)
(status, err) = keystone.check_group_lookup_args({'group_id': 'somegroupid'})
self.assertTrue(status)
self.assertIsNone(err)
def test_iterate_group_members(self):
with self.fake_keystone() as keystone:
(itt, err) = keystone.iterate_group_members({'group_id': 'somegroupid'})
self.assertIsNone(err)
results = list(itt)
results.sort()
self.assertEquals(2, len(results))
self.assertEquals('adminuser', results[0][0].id)
self.assertEquals('cool.user', results[1][0].id)
if __name__ == '__main__':
unittest.main()

582
test/test_ldap.py Normal file
View file

@ -0,0 +1,582 @@
import unittest
import ldap
from app import app
from initdb import setup_database_for_testing, finished_database_for_testing
from data.users import LDAPUsers
from data import model
from mockldap import MockLdap
from mock import patch
from contextlib import contextmanager
def _create_ldap(requires_email=True):
base_dn = ['dc=quay', 'dc=io']
admin_dn = 'uid=testy,ou=employees,dc=quay,dc=io'
admin_passwd = 'password'
user_rdn = ['ou=employees']
uid_attr = 'uid'
email_attr = 'mail'
secondary_user_rdns = ['ou=otheremployees']
ldap = LDAPUsers('ldap://localhost', base_dn, admin_dn, admin_passwd, user_rdn,
uid_attr, email_attr, secondary_user_rdns=secondary_user_rdns,
requires_email=requires_email)
return ldap
@contextmanager
def mock_ldap(requires_email=True):
mock_data = {
'dc=quay,dc=io': {'dc': ['quay', 'io']},
'ou=employees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'employees'
},
'ou=otheremployees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'otheremployees'
},
'cn=AwesomeFolk,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'cn': 'AwesomeFolk'
},
'uid=testy,ou=employees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'employees',
'uid': ['testy'],
'userPassword': ['password'],
'mail': ['bar@baz.com'],
'memberOf': ['cn=AwesomeFolk,dc=quay,dc=io', 'cn=*Guys,dc=quay,dc=io'],
},
'uid=someuser,ou=employees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'employees',
'uid': ['someuser'],
'userPassword': ['somepass'],
'mail': ['foo@bar.com'],
'memberOf': ['cn=AwesomeFolk,dc=quay,dc=io', 'cn=*Guys,dc=quay,dc=io'],
},
'uid=nomail,ou=employees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'employees',
'uid': ['nomail'],
'userPassword': ['somepass']
},
'uid=cool.user,ou=employees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'employees',
'uid': ['cool.user', 'referred'],
'userPassword': ['somepass'],
'mail': ['foo@bar.com']
},
'uid=referred,ou=employees,dc=quay,dc=io': {
'uid': ['referred'],
'_referral': 'ldap:///uid=cool.user,ou=employees,dc=quay,dc=io'
},
'uid=invalidreferred,ou=employees,dc=quay,dc=io': {
'uid': ['invalidreferred'],
'_referral': 'ldap:///uid=someinvaliduser,ou=employees,dc=quay,dc=io'
},
'uid=multientry,ou=subgroup1,ou=employees,dc=quay,dc=io': {
'uid': ['multientry'],
'mail': ['foo@bar.com'],
'userPassword': ['somepass'],
},
'uid=multientry,ou=subgroup2,ou=employees,dc=quay,dc=io': {
'uid': ['multientry'],
'another': ['key']
},
'uid=secondaryuser,ou=otheremployees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'otheremployees',
'uid': ['secondaryuser'],
'userPassword': ['somepass'],
'mail': ['foosecondary@bar.com']
},
# Feature: Email Blacklisting
'uid=blacklistedcom,ou=otheremployees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'otheremployees',
'uid': ['blacklistedcom'],
'userPassword': ['somepass'],
'mail': ['foo@blacklisted.com']
},
'uid=blacklistednet,ou=otheremployees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'otheremployees',
'uid': ['blacklistednet'],
'userPassword': ['somepass'],
'mail': ['foo@blacklisted.net']
},
'uid=blacklistedorg,ou=otheremployees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'otheremployees',
'uid': ['blacklistedorg'],
'userPassword': ['somepass'],
'mail': ['foo@blacklisted.org']
},
'uid=notblacklistedcom,ou=otheremployees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'otheremployees',
'uid': ['notblacklistedcom'],
'userPassword': ['somepass'],
'mail': ['foo@notblacklisted.com']
},
}
if not requires_email:
for path in mock_data:
mock_data[path].pop('mail', None)
mockldap = MockLdap(mock_data)
def initializer(uri, trace_level=0):
obj = mockldap[uri]
# Seed to "support" wildcard queries, which MockLDAP does not support natively.
cool_block = {
'dc': ['quay', 'io'],
'ou': 'employees',
'uid': ['cool.user', 'referred'],
'userPassword': ['somepass'],
'mail': ['foo@bar.com']
}
if not requires_email:
cool_block.pop('mail', None)
obj.search_s.seed('ou=employees,dc=quay,dc=io', 2, '(|(uid=cool*)(mail=cool*))')([
('uid=cool.user,ou=employees,dc=quay,dc=io', cool_block)
])
obj.search_s.seed('ou=otheremployees,dc=quay,dc=io', 2, '(|(uid=cool*)(mail=cool*))')([])
obj.search_s.seed('ou=employees,dc=quay,dc=io', 2, '(|(uid=unknown*)(mail=unknown*))')([])
obj.search_s.seed('ou=otheremployees,dc=quay,dc=io', 2,
'(|(uid=unknown*)(mail=unknown*))')([])
no_users_found_exception = Exception()
no_users_found_exception.message = { 'matched': 'dc=quay,dc=io', 'desc': 'No such object' }
obj.search_s.seed('ou=nonexistent,dc=quay,dc=io', 2)(no_users_found_exception)
obj.search_s.seed('ou=employees,dc=quay,dc=io', 2)([
('uid=cool.user,ou=employees,dc=quay,dc=io', cool_block)
])
obj.search.seed('ou=employees,dc=quay,dc=io', 2, '(objectClass=*)')([
('uid=cool.user,ou=employees,dc=quay,dc=io', cool_block)
])
obj.search.seed('ou=employees,dc=quay,dc=io', 2)([
('uid=cool.user,ou=employees,dc=quay,dc=io', cool_block)
])
obj._results = {}
original_result_fn = obj.result
def result(messageid):
if messageid is None:
return None, [], None, None
# NOTE: Added because of weirdness with using mock-ldap.
if isinstance(messageid, list):
return None, messageid
if messageid in obj._results:
return obj._results[messageid]
return original_result_fn(messageid)
def result3(messageid):
if messageid is None:
return None, [], None, None
return obj._results[messageid]
def search_ext(user_search_dn, scope, search_flt=None, serverctrls=None,
sizelimit=None, attrlist=None):
if scope != ldap.SCOPE_SUBTREE:
return None
if not serverctrls:
if search_flt:
rdata = obj.search_s(user_search_dn, scope, search_flt, attrlist=attrlist)
else:
if attrlist:
rdata = obj.search_s(user_search_dn, scope, attrlist=attrlist)
else:
rdata = obj.search_s(user_search_dn, scope)
obj._results['messageid'] = (None, rdata)
return 'messageid'
page_control = serverctrls[0]
if page_control.controlType != ldap.controls.SimplePagedResultsControl.controlType:
return None
if search_flt:
msgid = obj.search(user_search_dn, scope, search_flt, attrlist=attrlist)
else:
if attrlist:
msgid = obj.search(user_search_dn, scope, attrlist=attrlist)
else:
msgid = obj.search(user_search_dn, scope)
_, rdata = obj.result(msgid)
msgid = 'messageid'
cookie = int(page_control.cookie) if page_control.cookie else 0
results = rdata[cookie:cookie+page_control.size]
cookie = cookie + page_control.size
if cookie > len(results):
page_control.cookie = None
else:
page_control.cookie = cookie
obj._results['messageid'] = (None, results, None, [page_control])
return msgid
def search_ext_s(user_search_dn, scope, sizelimit=None):
return [obj.search_s(user_search_dn, scope)]
obj.search_ext = search_ext
obj.result = result
obj.result3 = result3
obj.search_ext_s = search_ext_s
return obj
mockldap.start()
with patch('ldap.initialize', new=initializer):
yield _create_ldap(requires_email=requires_email)
mockldap.stop()
class TestLDAP(unittest.TestCase):
def setUp(self):
setup_database_for_testing(self)
self.app = app.test_client()
self.ctx = app.test_request_context()
self.ctx.__enter__()
def tearDown(self):
finished_database_for_testing(self)
self.ctx.__exit__(True, None, None)
def test_invalid_admin_password(self):
base_dn = ['dc=quay', 'dc=io']
admin_dn = 'uid=testy,ou=employees,dc=quay,dc=io'
admin_passwd = 'INVALIDPASSWORD'
user_rdn = ['ou=employees']
uid_attr = 'uid'
email_attr = 'mail'
with mock_ldap():
ldap = LDAPUsers('ldap://localhost', base_dn, admin_dn, admin_passwd, user_rdn,
uid_attr, email_attr)
# Try to login.
(response, err_msg) = ldap.verify_and_link_user('someuser', 'somepass')
self.assertIsNone(response)
self.assertEquals('LDAP Admin dn or password is invalid', err_msg)
def test_login(self):
with mock_ldap() as ldap:
# Verify we can login.
(response, _) = ldap.verify_and_link_user('someuser', 'somepass')
self.assertEquals(response.username, 'someuser')
self.assertTrue(model.user.has_user_prompt(response, 'confirm_username'))
# Verify we can confirm the user.
(response, _) = ldap.confirm_existing_user('someuser', 'somepass')
self.assertEquals(response.username, 'someuser')
def test_login_empty_password(self):
with mock_ldap() as ldap:
# Verify we cannot login.
(response, err_msg) = ldap.verify_and_link_user('someuser', '')
self.assertIsNone(response)
self.assertEquals(err_msg, 'Anonymous binding not allowed')
# Verify we cannot confirm the user.
(response, err_msg) = ldap.confirm_existing_user('someuser', '')
self.assertIsNone(response)
self.assertEquals(err_msg, 'Invalid user')
def test_login_whitespace_password(self):
with mock_ldap() as ldap:
# Verify we cannot login.
(response, err_msg) = ldap.verify_and_link_user('someuser', ' ')
self.assertIsNone(response)
self.assertEquals(err_msg, 'Invalid password')
# Verify we cannot confirm the user.
(response, err_msg) = ldap.confirm_existing_user('someuser', ' ')
self.assertIsNone(response)
self.assertEquals(err_msg, 'Invalid user')
def test_login_secondary(self):
with mock_ldap() as ldap:
# Verify we can login.
(response, _) = ldap.verify_and_link_user('secondaryuser', 'somepass')
self.assertEquals(response.username, 'secondaryuser')
# Verify we can confirm the user.
(response, _) = ldap.confirm_existing_user('secondaryuser', 'somepass')
self.assertEquals(response.username, 'secondaryuser')
def test_invalid_wildcard(self):
with mock_ldap() as ldap:
# Verify we cannot login with a wildcard.
(response, err_msg) = ldap.verify_and_link_user('some*', 'somepass')
self.assertIsNone(response)
self.assertEquals(err_msg, 'Username not found')
# Verify we cannot confirm the user.
(response, err_msg) = ldap.confirm_existing_user('some*', 'somepass')
self.assertIsNone(response)
self.assertEquals(err_msg, 'Invalid user')
def test_invalid_password(self):
with mock_ldap() as ldap:
# Verify we cannot login with an invalid password.
(response, err_msg) = ldap.verify_and_link_user('someuser', 'invalidpass')
self.assertIsNone(response)
self.assertEquals(err_msg, 'Invalid password')
# Verify we cannot confirm the user.
(response, err_msg) = ldap.confirm_existing_user('someuser', 'invalidpass')
self.assertIsNone(response)
self.assertEquals(err_msg, 'Invalid user')
def test_missing_mail(self):
with mock_ldap() as ldap:
(response, err_msg) = ldap.get_user('nomail')
self.assertIsNone(response)
self.assertEquals('Missing mail field "mail" in user record', err_msg)
def test_missing_mail_allowed(self):
with mock_ldap(requires_email=False) as ldap:
(response, _) = ldap.get_user('nomail')
self.assertEquals(response.username, 'nomail')
def test_confirm_different_username(self):
with mock_ldap() as ldap:
# Verify that the user is logged in and their username was adjusted.
(response, _) = ldap.verify_and_link_user('cool.user', 'somepass')
self.assertEquals(response.username, 'cool_user')
# Verify we can confirm the user's quay username.
(response, _) = ldap.confirm_existing_user('cool_user', 'somepass')
self.assertEquals(response.username, 'cool_user')
# Verify that we *cannot* confirm the LDAP username.
(response, _) = ldap.confirm_existing_user('cool.user', 'somepass')
self.assertIsNone(response)
def test_referral(self):
with mock_ldap() as ldap:
(response, _) = ldap.verify_and_link_user('referred', 'somepass')
self.assertEquals(response.username, 'cool_user')
# Verify we can confirm the user's quay username.
(response, _) = ldap.confirm_existing_user('cool_user', 'somepass')
self.assertEquals(response.username, 'cool_user')
def test_invalid_referral(self):
with mock_ldap() as ldap:
(response, _) = ldap.verify_and_link_user('invalidreferred', 'somepass')
self.assertIsNone(response)
def test_multientry(self):
with mock_ldap() as ldap:
(response, _) = ldap.verify_and_link_user('multientry', 'somepass')
self.assertEquals(response.username, 'multientry')
def test_login_empty_userdn(self):
with mock_ldap():
base_dn = ['ou=employees', 'dc=quay', 'dc=io']
admin_dn = 'uid=testy,ou=employees,dc=quay,dc=io'
admin_passwd = 'password'
user_rdn = []
uid_attr = 'uid'
email_attr = 'mail'
secondary_user_rdns = ['ou=otheremployees']
ldap = LDAPUsers('ldap://localhost', base_dn, admin_dn, admin_passwd, user_rdn,
uid_attr, email_attr, secondary_user_rdns=secondary_user_rdns)
# Verify we can login.
(response, _) = ldap.verify_and_link_user('someuser', 'somepass')
self.assertEquals(response.username, 'someuser')
# Verify we can confirm the user.
(response, _) = ldap.confirm_existing_user('someuser', 'somepass')
self.assertEquals(response.username, 'someuser')
def test_link_user(self):
with mock_ldap() as ldap:
# Link someuser.
user, error_message = ldap.link_user('someuser')
self.assertIsNone(error_message)
self.assertIsNotNone(user)
self.assertEquals('someuser', user.username)
# Link again. Should return the same user record.
user_again, _ = ldap.link_user('someuser')
self.assertEquals(user_again.id, user.id)
# Confirm someuser.
result, _ = ldap.confirm_existing_user('someuser', 'somepass')
self.assertIsNotNone(result)
self.assertEquals('someuser', result.username)
self.assertTrue(model.user.has_user_prompt(user, 'confirm_username'))
def test_query(self):
with mock_ldap() as ldap:
# Lookup cool.
(response, federated_id, error_message) = ldap.query_users('cool')
self.assertIsNone(error_message)
self.assertEquals(1, len(response))
self.assertEquals('ldap', federated_id)
user_info = response[0]
self.assertEquals("cool.user", user_info.username)
self.assertEquals("foo@bar.com", user_info.email)
# Lookup unknown.
(response, federated_id, error_message) = ldap.query_users('unknown')
self.assertIsNone(error_message)
self.assertEquals(0, len(response))
self.assertEquals('ldap', federated_id)
def test_timeout(self):
base_dn = ['dc=quay', 'dc=io']
admin_dn = 'uid=testy,ou=employees,dc=quay,dc=io'
admin_passwd = 'password'
user_rdn = ['ou=employees']
uid_attr = 'uid'
email_attr = 'mail'
secondary_user_rdns = ['ou=otheremployees']
with self.assertRaisesRegexp(Exception, "Can't contact LDAP server"):
ldap = LDAPUsers('ldap://localhost', base_dn, admin_dn, admin_passwd, user_rdn,
uid_attr, email_attr, secondary_user_rdns=secondary_user_rdns,
requires_email=False, timeout=5)
ldap.query_users('cool')
def test_iterate_group_members(self):
with mock_ldap() as ldap:
(it, err) = ldap.iterate_group_members({'group_dn': 'cn=AwesomeFolk'},
disable_pagination=True)
self.assertIsNone(err)
results = list(it)
self.assertEquals(2, len(results))
first = results[0][0]
second = results[1][0]
if first.id == 'testy':
testy, someuser = first, second
else:
testy, someuser = second, first
self.assertEquals('testy', testy.id)
self.assertEquals('testy', testy.username)
self.assertEquals('bar@baz.com', testy.email)
self.assertEquals('someuser', someuser.id)
self.assertEquals('someuser', someuser.username)
self.assertEquals('foo@bar.com', someuser.email)
def test_iterate_group_members_with_pagination(self):
with mock_ldap() as ldap:
for dn in ['cn=AwesomeFolk', 'cn=*Guys']:
(it, err) = ldap.iterate_group_members({'group_dn': dn}, page_size=1)
self.assertIsNone(err)
results = list(it)
self.assertEquals(2, len(results))
first = results[0][0]
second = results[1][0]
if first.id == 'testy':
testy, someuser = first, second
else:
testy, someuser = second, first
self.assertEquals('testy', testy.id)
self.assertEquals('testy', testy.username)
self.assertEquals('bar@baz.com', testy.email)
self.assertEquals('someuser', someuser.id)
self.assertEquals('someuser', someuser.username)
self.assertEquals('foo@bar.com', someuser.email)
def test_check_group_lookup_args(self):
with mock_ldap() as ldap:
(result, err) = ldap.check_group_lookup_args({'group_dn': 'cn=invalid'},
disable_pagination=True)
self.assertFalse(result)
self.assertIsNotNone(err)
(result, err) = ldap.check_group_lookup_args({'group_dn': 'cn=AwesomeFolk'},
disable_pagination=True)
self.assertTrue(result)
self.assertIsNone(err)
(result, err) = ldap.check_group_lookup_args({'group_dn': 'cn=*Guys'},
disable_pagination=True)
self.assertTrue(result)
self.assertIsNone(err)
def test_metadata(self):
with mock_ldap() as ldap:
assert 'base_dn' in ldap.service_metadata()
def test_at_least_one_user_exists_invalid_creds(self):
base_dn = ['dc=quay', 'dc=io']
admin_dn = 'uid=testy,ou=employees,dc=quay,dc=io'
admin_passwd = 'INVALIDPASSWORD'
user_rdn = ['ou=employees']
uid_attr = 'uid'
email_attr = 'mail'
with mock_ldap():
ldap = LDAPUsers('ldap://localhost', base_dn, admin_dn, admin_passwd, user_rdn,
uid_attr, email_attr)
# Try to query with invalid credentials.
(response, err_msg) = ldap.at_least_one_user_exists()
self.assertFalse(response)
self.assertEquals('LDAP Admin dn or password is invalid', err_msg)
def test_at_least_one_user_exists_no_users(self):
base_dn = ['dc=quay', 'dc=io']
admin_dn = 'uid=testy,ou=employees,dc=quay,dc=io'
admin_passwd = 'password'
user_rdn = ['ou=nonexistent']
uid_attr = 'uid'
email_attr = 'mail'
with mock_ldap():
ldap = LDAPUsers('ldap://localhost', base_dn, admin_dn, admin_passwd, user_rdn,
uid_attr, email_attr)
# Try to find users in a nonexistent group.
(response, err_msg) = ldap.at_least_one_user_exists()
self.assertFalse(response)
assert err_msg is not None
def test_at_least_one_user_exists_true(self):
with mock_ldap() as ldap:
# Ensure we have at least a single user in the valid group
(response, err_msg) = ldap.at_least_one_user_exists()
self.assertIsNone(err_msg)
self.assertTrue(response)
if __name__ == '__main__':
unittest.main()

209
test/test_oauth_login.py Normal file
View file

@ -0,0 +1,209 @@
import json as py_json
import time
import unittest
import urlparse
import jwt
from Crypto.PublicKey import RSA
from httmock import urlmatch, HTTMock
from jwkest.jwk import RSAKey
from app import app, authentication
from data import model
from endpoints.oauth.login import oauthlogin as oauthlogin_bp
from test.test_endpoints import EndpointTestCase
from test.test_ldap import mock_ldap
class AuthForTesting(object):
def __init__(self, auth_engine):
self.auth_engine = auth_engine
self.existing_state = None
def __enter__(self):
self.existing_state = authentication.state
authentication.state = self.auth_engine
def __exit__(self, type, value, traceback):
authentication.state = self.existing_state
try:
app.register_blueprint(oauthlogin_bp, url_prefix='/oauth2')
except ValueError:
# This blueprint was already registered
pass
class OAuthLoginTestCase(EndpointTestCase):
def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident,
new_username, test_attach=True):
# Test callback.
created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username)
# Delete the created user.
self.assertNotEquals(created.username, 'devtable')
model.user.delete_user(created, [])
# Test attach.
if test_attach:
self.login('devtable', 'password')
self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable')
def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username):
self._set_csrf()
# No CSRF.
self.getResponse('oauthlogin.' + endpoint_name, expected_code=403)
# Invalid CSRF.
self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403)
# Valid CSRF, invalid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='invalidcode', expected_code=400)
# Valid CSRF, valid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='somecode', expected_code=302)
# Ensure the user was added/modified.
found_user = model.user.get_user(username)
self.assertIsNotNone(found_user)
federated_login = model.user.lookup_federated_login(found_user, service_name)
self.assertIsNotNone(federated_login)
self.assertEquals(federated_login.service_ident, service_ident)
return found_user
def test_google_oauth(self):
@urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token')
def account_handler(_, request):
parsed = dict(urlparse.parse_qsl(request.body))
if parsed['code'] == 'somecode':
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo')
def user_handler(_, __):
content = {
'id': 'someid',
'email': 'someemail@example.com',
'verified_email': True,
}
return py_json.dumps(content)
with HTTMock(account_handler, user_handler):
self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google',
'someid', 'someemail')
def test_github_oauth(self):
@urlmatch(netloc=r'github.com', path='/login/oauth/access_token')
def account_handler(url, _):
parsed = dict(urlparse.parse_qsl(url.query))
if parsed['code'] == 'somecode':
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'github.com', path='/api/v3/user')
def user_handler(_, __):
content = {
'id': 'someid',
'login': 'someusername'
}
return py_json.dumps(content)
@urlmatch(netloc=r'github.com', path='/api/v3/user/emails')
def email_handler(_, __):
content = [{
'email': 'someemail@example.com',
'verified': True,
'primary': True,
}]
return py_json.dumps(content)
with HTTMock(account_handler, email_handler, user_handler):
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
'someid', 'someusername')
def _get_oidc_mocks(self):
private_key = RSA.generate(2048)
generatedjwk = RSAKey(key=private_key.publickey()).serialize()
kid = 'somekey'
private_pem = private_key.exportKey('PEM')
token_data = {
'iss': app.config['TESTOIDC_LOGIN_CONFIG']['OIDC_SERVER'],
'aud': app.config['TESTOIDC_LOGIN_CONFIG']['CLIENT_ID'],
'nbf': int(time.time()),
'iat': int(time.time()),
'exp': int(time.time() + 600),
'sub': 'cool.user',
}
token_headers = {
'kid': kid,
}
id_token = jwt.encode(token_data, private_pem, 'RS256', headers=token_headers)
@urlmatch(netloc=r'fakeoidc', path='/token')
def token_handler(_, request):
if request.body.find("code=somecode") >= 0:
content = {'access_token': 'someaccesstoken', 'id_token': id_token}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'fakeoidc', path='/user')
def user_handler(_, __):
content = {
'sub': 'cool.user',
'preferred_username': 'someusername',
'email': 'someemail@example.com',
'email_verified': True,
}
return py_json.dumps(content)
@urlmatch(netloc=r'fakeoidc', path='/jwks')
def jwks_handler(_, __):
jwk = generatedjwk.copy()
jwk.update({'kid': kid})
content = {'keys': [jwk]}
return py_json.dumps(content)
@urlmatch(netloc=r'fakeoidc', path='.+openid.+')
def discovery_handler(_, __):
content = {
'scopes_supported': ['profile'],
'authorization_endpoint': 'http://fakeoidc/authorize',
'token_endpoint': 'http://fakeoidc/token',
'userinfo_endpoint': 'http://fakeoidc/userinfo',
'jwks_uri': 'http://fakeoidc/jwks',
}
return py_json.dumps(content)
return (discovery_handler, jwks_handler, token_handler, user_handler)
def test_oidc_database_auth(self):
oidc_mocks = self._get_oidc_mocks()
with HTTMock(*oidc_mocks):
self.invoke_oauth_tests('testoidc_oauth_callback', 'testoidc_oauth_attach', 'testoidc',
'cool.user', 'someusername')
def test_oidc_ldap_auth(self):
# Test with database auth.
oidc_mocks = self._get_oidc_mocks()
with mock_ldap() as ldap:
with AuthForTesting(ldap):
with HTTMock(*oidc_mocks):
self.invoke_oauth_tests('testoidc_oauth_callback', 'testoidc_oauth_attach', 'testoidc',
'cool.user', 'cool_user', test_attach=False)
if __name__ == '__main__':
unittest.main()

819
test/test_secscan.py Normal file
View file

@ -0,0 +1,819 @@
import json
import time
import unittest
from app import app, storage, notification_queue, url_scheme_and_hostname
from data import model
from data.database import Image, IMAGE_NOT_SCANNED_ENGINE_VERSION
from endpoints.v2 import v2_bp
from initdb import setup_database_for_testing, finished_database_for_testing
from notifications.notificationevent import VulnerabilityFoundEvent
from util.secscan.secscan_util import get_blob_download_uri_getter
from util.morecollections import AttrDict
from util.secscan.api import SecurityScannerAPI, APIRequestFailure
from util.secscan.analyzer import LayerAnalyzer
from util.secscan.fake import fake_security_scanner
from util.secscan.notifier import SecurityNotificationHandler, ProcessNotificationPageResult
from util.security.instancekeys import InstanceKeys
from workers.security_notification_worker import SecurityNotificationWorker
ADMIN_ACCESS_USER = 'devtable'
SIMPLE_REPO = 'simple'
COMPLEX_REPO = 'complex'
def process_notification_data(notification_data):
handler = SecurityNotificationHandler(100)
result = handler.process_notification_page_data(notification_data)
handler.send_notifications()
return result == ProcessNotificationPageResult.FINISHED_PROCESSING
class TestSecurityScanner(unittest.TestCase):
def setUp(self):
# Enable direct download in fake storage.
storage.put_content(['local_us'], 'supports_direct_download', 'true')
# Have fake storage say all files exist for the duration of the test.
storage.put_content(['local_us'], 'all_files_exist', 'true')
# Setup the database with fake storage.
setup_database_for_testing(self)
self.app = app.test_client()
self.ctx = app.test_request_context()
self.ctx.__enter__()
instance_keys = InstanceKeys(app)
self.api = SecurityScannerAPI(app.config, storage, app.config['SERVER_HOSTNAME'], app.config['HTTPCLIENT'],
uri_creator=get_blob_download_uri_getter(app.test_request_context('/'),
url_scheme_and_hostname),
instance_keys=instance_keys)
def tearDown(self):
storage.remove(['local_us'], 'supports_direct_download')
storage.remove(['local_us'], 'all_files_exist')
finished_database_for_testing(self)
self.ctx.__exit__(True, None, None)
def assertAnalyzed(self, layer, security_scanner, isAnalyzed, engineVersion):
self.assertEquals(isAnalyzed, layer.security_indexed)
self.assertEquals(engineVersion, layer.security_indexed_engine)
if isAnalyzed:
self.assertTrue(security_scanner.has_layer(security_scanner.layer_id(layer)))
# Ensure all parent layers are marked as analyzed.
parents = model.image.get_parent_images(ADMIN_ACCESS_USER, SIMPLE_REPO, layer)
for parent in parents:
self.assertTrue(parent.security_indexed)
self.assertEquals(engineVersion, parent.security_indexed_engine)
self.assertTrue(security_scanner.has_layer(security_scanner.layer_id(parent)))
def test_get_layer(self):
""" Test for basic retrieval of layers from the security scanner. """
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
with fake_security_scanner() as security_scanner:
# Ensure the layer doesn't exist yet.
self.assertFalse(security_scanner.has_layer(security_scanner.layer_id(layer)))
self.assertIsNone(self.api.get_layer_data(layer))
# Add the layer.
security_scanner.add_layer(security_scanner.layer_id(layer))
# Retrieve the results.
result = self.api.get_layer_data(layer, include_vulnerabilities=True)
self.assertIsNotNone(result)
self.assertEquals(result['Layer']['Name'], security_scanner.layer_id(layer))
def test_analyze_layer_nodirectdownload_success(self):
""" Tests analyzing a layer when direct download is disabled. """
# Disable direct download in fake storage.
storage.put_content(['local_us'], 'supports_direct_download', 'false')
try:
app.register_blueprint(v2_bp, url_prefix='/v2')
except:
# Already registered.
pass
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
# Ensure that the download is a registry+JWT download.
uri, auth_header = self.api._get_image_url_and_auth(layer)
self.assertIsNotNone(uri)
self.assertIsNotNone(auth_header)
# Ensure the download doesn't work without the header.
rv = self.app.head(uri)
self.assertEquals(rv.status_code, 401)
# Ensure the download works with the header. Note we use a HEAD here, as GET causes DB
# access which messes with the test runner's rollback.
rv = self.app.head(uri, headers=[('authorization', auth_header)])
self.assertEquals(rv.status_code, 200)
# Ensure the code works when called via analyze.
with fake_security_scanner() as security_scanner:
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
def test_analyze_layer_success(self):
""" Tests that analyzing a layer successfully marks it as analyzed. """
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
with fake_security_scanner() as security_scanner:
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
def test_analyze_layer_failure(self):
""" Tests that failing to analyze a layer (because it 422s) marks it as analyzed but failed. """
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
with fake_security_scanner() as security_scanner:
security_scanner.set_fail_layer_id(security_scanner.layer_id(layer))
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, False, 1)
def test_analyze_layer_internal_error(self):
""" Tests that failing to analyze a layer (because it 500s) marks it as not analyzed. """
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
with fake_security_scanner() as security_scanner:
security_scanner.set_internal_error_layer_id(security_scanner.layer_id(layer))
analyzer = LayerAnalyzer(app.config, self.api)
with self.assertRaises(APIRequestFailure):
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, False, -1)
def test_analyze_layer_error(self):
""" Tests that failing to analyze a layer (because it 400s) marks it as analyzed but failed. """
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
with fake_security_scanner() as security_scanner:
# Make is so trying to analyze the parent will fail with an error.
security_scanner.set_error_layer_id(security_scanner.layer_id(layer.parent))
# Try to the layer and its parents, but with one request causing an error.
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
# Make sure it is marked as analyzed, but in a failed state.
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, False, 1)
def test_analyze_layer_unexpected_status(self):
""" Tests that a response from a scanner with an unexpected status code fails correctly. """
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
with fake_security_scanner() as security_scanner:
# Make is so trying to analyze the parent will fail with an error.
security_scanner.set_unexpected_status_layer_id(security_scanner.layer_id(layer.parent))
# Try to the layer and its parents, but with one request causing an error.
analyzer = LayerAnalyzer(app.config, self.api)
with self.assertRaises(APIRequestFailure):
analyzer.analyze_recursively(layer)
# Make sure it isn't analyzed.
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, False, -1)
def test_analyze_layer_missing_parent_handled(self):
""" Tests that a missing parent causes an automatic reanalysis, which succeeds. """
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
with fake_security_scanner() as security_scanner:
# Analyze the layer and its parents.
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
# Make sure it was analyzed.
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
# Mark the layer as not yet scanned.
layer.security_indexed_engine = IMAGE_NOT_SCANNED_ENGINE_VERSION
layer.security_indexed = False
layer.save()
# Remove the layer's parent entirely from the security scanner.
security_scanner.remove_layer(security_scanner.layer_id(layer.parent))
# Analyze again, which should properly re-analyze the missing parent and this layer.
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
def test_analyze_layer_invalid_parent(self):
""" Tests that trying to reanalyze a parent that is invalid causes the layer to be marked
as analyzed, but failed.
"""
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
with fake_security_scanner() as security_scanner:
# Analyze the layer and its parents.
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
# Make sure it was analyzed.
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
# Mark the layer as not yet scanned.
layer.security_indexed_engine = IMAGE_NOT_SCANNED_ENGINE_VERSION
layer.security_indexed = False
layer.save()
# Remove the layer's parent entirely from the security scanner.
security_scanner.remove_layer(security_scanner.layer_id(layer.parent))
# Make is so trying to analyze the parent will fail.
security_scanner.set_error_layer_id(security_scanner.layer_id(layer.parent))
# Try to analyze again, which should try to reindex the parent and fail.
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, False, 1)
def test_analyze_layer_unsupported_parent(self):
""" Tests that attempting to analyze a layer whose parent is unanalyzable, results in the layer
being marked as analyzed, but failed.
"""
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
with fake_security_scanner() as security_scanner:
# Make is so trying to analyze the parent will fail.
security_scanner.set_fail_layer_id(security_scanner.layer_id(layer.parent))
# Attempt to the layer and its parents. This should mark the layer itself as unanalyzable.
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, False, 1)
def test_analyze_layer_missing_storage(self):
""" Tests trying to analyze a layer with missing storage. """
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
# Delete the storage for the layer.
path = model.storage.get_layer_path(layer.storage)
locations = app.config['DISTRIBUTED_STORAGE_PREFERENCE']
storage.remove(locations, path)
storage.remove(locations, 'all_files_exist')
with fake_security_scanner() as security_scanner:
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, False, 1)
def assert_analyze_layer_notify(self, security_indexed_engine, security_indexed,
expect_notification):
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
self.assertFalse(layer.security_indexed)
self.assertEquals(-1, layer.security_indexed_engine)
# Ensure there are no existing events.
self.assertIsNone(notification_queue.get())
# Add a repo event for the layer.
repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
model.notification.create_repo_notification(repo, 'vulnerability_found',
'quay_notification', {}, {'level': 100})
# Update the layer's state before analyzing.
layer.security_indexed_engine = security_indexed_engine
layer.security_indexed = security_indexed
layer.save()
with fake_security_scanner() as security_scanner:
security_scanner.set_vulns(security_scanner.layer_id(layer), [
{
"Name": "CVE-2014-9471",
"Namespace": "debian:8",
"Description": "Some service",
"Link": "https://security-tracker.debian.org/tracker/CVE-2014-9471",
"Severity": "Low",
"FixedBy": "9.23-5"
},
{
"Name": "CVE-2016-7530",
"Namespace": "debian:8",
"Description": "Some other service",
"Link": "https://security-tracker.debian.org/tracker/CVE-2016-7530",
"Severity": "Unknown",
"FixedBy": "19.343-2"
}
])
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
# Ensure an event was written for the tag (if necessary).
time.sleep(1)
queue_item = notification_queue.get()
if expect_notification:
self.assertIsNotNone(queue_item)
body = json.loads(queue_item.body)
self.assertEquals(set(['latest', 'prod']), set(body['event_data']['tags']))
self.assertEquals('CVE-2014-9471', body['event_data']['vulnerability']['id'])
self.assertEquals('Low', body['event_data']['vulnerability']['priority'])
self.assertTrue(body['event_data']['vulnerability']['has_fix'])
self.assertEquals('CVE-2014-9471', body['event_data']['vulnerabilities'][0]['id'])
self.assertEquals(2, len(body['event_data']['vulnerabilities']))
# Ensure we get the correct event message out as well.
event = VulnerabilityFoundEvent()
msg = '1 Low and 1 more vulnerabilities were detected in repository devtable/simple in 2 tags'
self.assertEquals(msg, event.get_summary(body['event_data'], {}))
self.assertEquals('info', event.get_level(body['event_data'], {}))
else:
self.assertIsNone(queue_item)
# Ensure its security indexed engine was updated.
updated_layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertEquals(updated_layer.id, layer.id)
self.assertTrue(updated_layer.security_indexed_engine > 0)
def test_analyze_layer_success_events(self):
# Not previously indexed at all => Notification
self.assert_analyze_layer_notify(IMAGE_NOT_SCANNED_ENGINE_VERSION, False, True)
def test_analyze_layer_success_no_notification(self):
# Previously successfully indexed => No notification
self.assert_analyze_layer_notify(0, True, False)
def test_analyze_layer_failed_then_success_notification(self):
# Previously failed to index => Notification
self.assert_analyze_layer_notify(0, False, True)
def test_notification_new_layers_not_vulnerable(self):
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
layer_id = '%s.%s' % (layer.docker_image_id, layer.storage.uuid)
# Add a repo event for the layer.
repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
model.notification.create_repo_notification(repo, 'vulnerability_found', 'quay_notification',
{}, {'level': 100})
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
# Fire off the notification processing.
with fake_security_scanner() as security_scanner:
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
# Add a notification for the layer.
notification_data = security_scanner.add_notification([layer_id], [], {}, {})
# Process the notification.
self.assertTrue(process_notification_data(notification_data))
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
def test_notification_delete(self):
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
layer_id = '%s.%s' % (layer.docker_image_id, layer.storage.uuid)
# Add a repo event for the layer.
repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
model.notification.create_repo_notification(repo, 'vulnerability_found', 'quay_notification',
{}, {'level': 100})
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
# Fire off the notification processing.
with fake_security_scanner() as security_scanner:
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
# Add a notification for the layer.
notification_data = security_scanner.add_notification([layer_id], None, {}, None)
# Process the notification.
self.assertTrue(process_notification_data(notification_data))
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
def test_notification_new_layers(self):
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
layer_id = '%s.%s' % (layer.docker_image_id, layer.storage.uuid)
# Add a repo event for the layer.
repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
model.notification.create_repo_notification(repo, 'vulnerability_found', 'quay_notification',
{}, {'level': 100})
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
# Fire off the notification processing.
with fake_security_scanner() as security_scanner:
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
vuln_info = {
"Name": "CVE-TEST",
"Namespace": "debian:8",
"Description": "Some service",
"Link": "https://security-tracker.debian.org/tracker/CVE-2014-9471",
"Severity": "Low",
"FixedIn": {"Version": "9.23-5"},
}
security_scanner.set_vulns(layer_id, [vuln_info])
# Add a notification for the layer.
notification_data = security_scanner.add_notification([], [layer_id], vuln_info, vuln_info)
# Process the notification.
self.assertTrue(process_notification_data(notification_data))
# Ensure an event was written for the tag.
time.sleep(1)
queue_item = notification_queue.get()
self.assertIsNotNone(queue_item)
item_body = json.loads(queue_item.body)
self.assertEquals(sorted(['prod', 'latest']), sorted(item_body['event_data']['tags']))
self.assertEquals('CVE-TEST', item_body['event_data']['vulnerability']['id'])
self.assertEquals('Low', item_body['event_data']['vulnerability']['priority'])
self.assertTrue(item_body['event_data']['vulnerability']['has_fix'])
def test_notification_no_new_layers(self):
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
# Add a repo event for the layer.
repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
model.notification.create_repo_notification(repo, 'vulnerability_found', 'quay_notification',
{}, {'level': 100})
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
# Fire off the notification processing.
with fake_security_scanner() as security_scanner:
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
# Add a notification for the layer.
notification_data = security_scanner.add_notification([], [], {}, {})
# Process the notification.
self.assertTrue(process_notification_data(notification_data))
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
def notification_tuple(self, notification):
# TODO: Replace this with a method once we refactor the notification stuff into its
# own module.
return AttrDict({
'event_config_dict': json.loads(notification.event_config_json),
'method_config_dict': json.loads(notification.config_json),
})
def test_notification_no_new_layers_increased_severity(self):
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
layer_id = '%s.%s' % (layer.docker_image_id, layer.storage.uuid)
# Add a repo event for the layer.
repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
notification = model.notification.create_repo_notification(repo, 'vulnerability_found',
'quay_notification', {},
{'level': 100})
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
# Fire off the notification processing.
with fake_security_scanner() as security_scanner:
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
old_vuln_info = {
"Name": "CVE-TEST",
"Namespace": "debian:8",
"Description": "Some service",
"Link": "https://security-tracker.debian.org/tracker/CVE-2014-9471",
"Severity": "Low",
}
new_vuln_info = {
"Name": "CVE-TEST",
"Namespace": "debian:8",
"Description": "Some service",
"Link": "https://security-tracker.debian.org/tracker/CVE-2014-9471",
"Severity": "Critical",
"FixedIn": {'Version': "9.23-5"},
}
security_scanner.set_vulns(layer_id, [new_vuln_info])
# Add a notification for the layer.
notification_data = security_scanner.add_notification([layer_id], [layer_id],
old_vuln_info, new_vuln_info)
# Process the notification.
self.assertTrue(process_notification_data(notification_data))
# Ensure an event was written for the tag.
time.sleep(1)
queue_item = notification_queue.get()
self.assertIsNotNone(queue_item)
item_body = json.loads(queue_item.body)
self.assertEquals(sorted(['prod', 'latest']), sorted(item_body['event_data']['tags']))
self.assertEquals('CVE-TEST', item_body['event_data']['vulnerability']['id'])
self.assertEquals('Critical', item_body['event_data']['vulnerability']['priority'])
self.assertTrue(item_body['event_data']['vulnerability']['has_fix'])
# Verify that an event would be raised.
event_data = item_body['event_data']
notification = self.notification_tuple(notification)
self.assertTrue(VulnerabilityFoundEvent().should_perform(event_data, notification))
# Create another notification with a matching level and verify it will be raised.
notification = model.notification.create_repo_notification(repo, 'vulnerability_found',
'quay_notification', {},
{'level': 1})
notification = self.notification_tuple(notification)
self.assertTrue(VulnerabilityFoundEvent().should_perform(event_data, notification))
# Create another notification with a higher level and verify it won't be raised.
notification = model.notification.create_repo_notification(repo, 'vulnerability_found',
'quay_notification', {},
{'level': 0})
notification = self.notification_tuple(notification)
self.assertFalse(VulnerabilityFoundEvent().should_perform(event_data, notification))
def test_select_images_to_scan(self):
# Set all images to have a security index of a version to that of the config.
expected_version = app.config['SECURITY_SCANNER_ENGINE_VERSION_TARGET']
Image.update(security_indexed_engine=expected_version).execute()
# Ensure no images are available for scanning.
self.assertIsNone(model.image.get_min_id_for_sec_scan(expected_version))
self.assertTrue(len(model.image.get_images_eligible_for_scan(expected_version)) == 0)
# Check for a higher version.
self.assertIsNotNone(model.image.get_min_id_for_sec_scan(expected_version + 1))
self.assertTrue(len(model.image.get_images_eligible_for_scan(expected_version + 1)) > 0)
def test_notification_worker(self):
layer1 = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
layer2 = model.tag.get_tag_image(ADMIN_ACCESS_USER, COMPLEX_REPO, 'prod', include_storage=True)
# Add a repo events for the layers.
simple_repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
complex_repo = model.repository.get_repository(ADMIN_ACCESS_USER, COMPLEX_REPO)
model.notification.create_repo_notification(simple_repo, 'vulnerability_found',
'quay_notification', {}, {'level': 100})
model.notification.create_repo_notification(complex_repo, 'vulnerability_found',
'quay_notification', {}, {'level': 100})
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
with fake_security_scanner() as security_scanner:
# Test with an unknown notification.
worker = SecurityNotificationWorker(None)
self.assertFalse(worker.perform_notification_work({
'Name': 'unknownnotification'
}))
# Add some analyzed layers.
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer1)
analyzer.analyze_recursively(layer2)
# Add a notification with pages of data.
new_vuln_info = {
"Name": "CVE-TEST",
"Namespace": "debian:8",
"Description": "Some service",
"Link": "https://security-tracker.debian.org/tracker/CVE-2014-9471",
"Severity": "Critical",
"FixedIn": {'Version': "9.23-5"},
}
security_scanner.set_vulns(security_scanner.layer_id(layer1), [new_vuln_info])
security_scanner.set_vulns(security_scanner.layer_id(layer2), [new_vuln_info])
layer_ids = [security_scanner.layer_id(layer1), security_scanner.layer_id(layer2)]
notification_data = security_scanner.add_notification([], layer_ids, None, new_vuln_info)
# Test with a known notification with pages.
data = {
'Name': notification_data['Name'],
}
worker = SecurityNotificationWorker(None)
self.assertTrue(worker.perform_notification_work(data, layer_limit=2))
# Make sure all pages were processed by ensuring we have two notifications.
time.sleep(1)
self.assertIsNotNone(notification_queue.get())
self.assertIsNotNone(notification_queue.get())
def test_notification_worker_offset_pages_not_indexed(self):
# Try without indexes.
self.assert_notification_worker_offset_pages(indexed=False)
def test_notification_worker_offset_pages_indexed(self):
# Try with indexes.
self.assert_notification_worker_offset_pages(indexed=True)
def assert_notification_worker_offset_pages(self, indexed=False):
layer1 = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
layer2 = model.tag.get_tag_image(ADMIN_ACCESS_USER, COMPLEX_REPO, 'prod', include_storage=True)
# Add a repo events for the layers.
simple_repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
complex_repo = model.repository.get_repository(ADMIN_ACCESS_USER, COMPLEX_REPO)
model.notification.create_repo_notification(simple_repo, 'vulnerability_found',
'quay_notification', {}, {'level': 100})
model.notification.create_repo_notification(complex_repo, 'vulnerability_found',
'quay_notification', {}, {'level': 100})
# Ensure that there are no event queue items for the layer.
self.assertIsNone(notification_queue.get())
with fake_security_scanner() as security_scanner:
# Test with an unknown notification.
worker = SecurityNotificationWorker(None)
self.assertFalse(worker.perform_notification_work({
'Name': 'unknownnotification'
}))
# Add some analyzed layers.
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer1)
analyzer.analyze_recursively(layer2)
# Add a notification with pages of data.
new_vuln_info = {
"Name": "CVE-TEST",
"Namespace": "debian:8",
"Description": "Some service",
"Link": "https://security-tracker.debian.org/tracker/CVE-2014-9471",
"Severity": "Critical",
"FixedIn": {'Version': "9.23-5"},
}
security_scanner.set_vulns(security_scanner.layer_id(layer1), [new_vuln_info])
security_scanner.set_vulns(security_scanner.layer_id(layer2), [new_vuln_info])
# Define offsetting sets of layer IDs, to test cross-pagination support. In this test, we
# will only serve 2 layer IDs per page: the first page will serve both of the 'New' layer IDs,
# but since the first 2 'Old' layer IDs are "earlier" than the shared ID of
# `devtable/simple:latest`, they won't get served in the 'New' list until the *second* page.
# The notification handling system should correctly not notify for this layer, even though it
# is marked 'New' on page 1 and marked 'Old' on page 2. Clair will served these
# IDs sorted in the same manner.
idx_old_layer_ids = [{'LayerName': 'old1', 'Index': 1},
{'LayerName': 'old2', 'Index': 2},
{'LayerName': security_scanner.layer_id(layer1), 'Index': 3}]
idx_new_layer_ids = [{'LayerName': security_scanner.layer_id(layer1), 'Index': 3},
{'LayerName': security_scanner.layer_id(layer2), 'Index': 4}]
old_layer_ids = [t['LayerName'] for t in idx_old_layer_ids]
new_layer_ids = [t['LayerName'] for t in idx_new_layer_ids]
if not indexed:
idx_old_layer_ids = None
idx_new_layer_ids = None
notification_data = security_scanner.add_notification(old_layer_ids, new_layer_ids, None,
new_vuln_info, max_per_page=2,
indexed_old_layer_ids=idx_old_layer_ids,
indexed_new_layer_ids=idx_new_layer_ids)
# Test with a known notification with pages.
data = {
'Name': notification_data['Name'],
}
worker = SecurityNotificationWorker(None)
self.assertTrue(worker.perform_notification_work(data, layer_limit=2))
# Make sure all pages were processed by ensuring we have only one notification. If the second
# page was not processed, then the `Old` entry for layer1 will not be found, and we'd get two
# notifications.
time.sleep(1)
self.assertIsNotNone(notification_queue.get())
self.assertIsNone(notification_queue.get())
def test_layer_gc(self):
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest', include_storage=True)
# Delete the prod tag so that only the `latest` tag remains.
model.tag.delete_tag(ADMIN_ACCESS_USER, SIMPLE_REPO, 'prod')
with fake_security_scanner() as security_scanner:
# Analyze the layer.
analyzer = LayerAnalyzer(app.config, self.api)
analyzer.analyze_recursively(layer)
layer = model.tag.get_tag_image(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
self.assertAnalyzed(layer, security_scanner, True, 1)
self.assertTrue(security_scanner.has_layer(security_scanner.layer_id(layer)))
namespace_user = model.user.get_user(ADMIN_ACCESS_USER)
model.user.change_user_tag_expiration(namespace_user, 0)
# Delete the tag in the repository and GC.
model.tag.delete_tag(ADMIN_ACCESS_USER, SIMPLE_REPO, 'latest')
time.sleep(1)
repo = model.repository.get_repository(ADMIN_ACCESS_USER, SIMPLE_REPO)
model.gc.garbage_collect_repo(repo)
# Ensure that the security scanner no longer has the image.
self.assertFalse(security_scanner.has_layer(security_scanner.layer_id(layer)))
if __name__ == '__main__':
unittest.main()

4
test/test_sni.py Normal file
View file

@ -0,0 +1,4 @@
import ssl
def test_sni_support():
assert ssl.HAS_SNI

View file

@ -0,0 +1,121 @@
import unittest
import endpoints.decorated
import json
from app import app
from util.names import parse_namespace_repository
from initdb import setup_database_for_testing, finished_database_for_testing
from specs import build_v1_index_specs
from endpoints.v1 import v1_bp
app.register_blueprint(v1_bp, url_prefix='/v1')
NO_ACCESS_USER = 'freshuser'
READ_ACCESS_USER = 'reader'
CREATOR_ACCESS_USER = 'creator'
ADMIN_ACCESS_USER = 'devtable'
class EndpointTestCase(unittest.TestCase):
def setUp(self):
setup_database_for_testing(self)
def tearDown(self):
finished_database_for_testing(self)
class _SpecTestBuilder(type):
@staticmethod
def _test_generator(url, expected_status, open_kwargs, session_var_list):
def test(self):
with app.test_client() as c:
if session_var_list:
# Temporarily remove the teardown functions
teardown_funcs = []
if None in app.teardown_request_funcs:
teardown_funcs = app.teardown_request_funcs[None]
app.teardown_request_funcs[None] = []
with c.session_transaction() as sess:
for sess_key, sess_val in session_var_list:
sess[sess_key] = sess_val
# Restore the teardown functions
app.teardown_request_funcs[None] = teardown_funcs
rv = c.open(url, **open_kwargs)
msg = '%s %s: %s expected: %s' % (open_kwargs['method'], url,
rv.status_code, expected_status)
self.assertEqual(rv.status_code, expected_status, msg)
return test
def __new__(cls, name, bases, attrs):
with app.test_request_context() as ctx:
specs = attrs['spec_func']()
for test_spec in specs:
url, open_kwargs = test_spec.get_client_args()
if attrs['auth_username']:
basic_auth = test_spec.gen_basic_auth(attrs['auth_username'],
'password')
open_kwargs['headers'] = [('authorization', '%s' % basic_auth)]
session_vars = []
if test_spec.sess_repo:
ns, repo = parse_namespace_repository(test_spec.sess_repo, 'library')
session_vars.append(('namespace', ns))
session_vars.append(('repository', repo))
expected_status = getattr(test_spec, attrs['result_attr'])
test = _SpecTestBuilder._test_generator(url, expected_status,
open_kwargs,
session_vars)
test_name_url = url.replace('/', '_').replace('-', '_')
sess_repo = str(test_spec.sess_repo).replace('/', '_')
test_name = 'test_%s%s_%s_%s' % (open_kwargs['method'].lower(),
test_name_url, sess_repo, attrs['result_attr'])
attrs[test_name] = test
return type(name, bases, attrs)
class TestAnonymousAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v1_index_specs
result_attr = 'anon_code'
auth_username = None
class TestNoAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v1_index_specs
result_attr = 'no_access_code'
auth_username = NO_ACCESS_USER
class TestReadAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v1_index_specs
result_attr = 'read_code'
auth_username = READ_ACCESS_USER
class TestCreatorAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v1_index_specs
result_attr = 'creator_code'
auth_username = CREATOR_ACCESS_USER
class TestAdminAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v1_index_specs
result_attr = 'admin_code'
auth_username = ADMIN_ACCESS_USER
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1,116 @@
import unittest
import json
import endpoints.decorated # Register the various exceptions via decorators.
from app import app
from endpoints.v2 import v2_bp
from initdb import setup_database_for_testing, finished_database_for_testing
from test.specs import build_v2_index_specs
app.register_blueprint(v2_bp, url_prefix='/v2')
NO_ACCESS_USER = 'freshuser'
READ_ACCESS_USER = 'reader'
ADMIN_ACCESS_USER = 'devtable'
CREATOR_ACCESS_USER = 'creator'
class EndpointTestCase(unittest.TestCase):
def setUp(self):
setup_database_for_testing(self)
def tearDown(self):
finished_database_for_testing(self)
class _SpecTestBuilder(type):
@staticmethod
def _test_generator(url, test_spec, attrs):
def test(self):
with app.test_client() as c:
headers = []
expected_index_status = getattr(test_spec, attrs['result_attr'])
if attrs['auth_username']:
# Get a signed JWT.
username = attrs['auth_username']
password = 'password'
jwt_scope = test_spec.get_scope_string()
query_string = 'service=' + app.config['SERVER_HOSTNAME'] + '&scope=' + jwt_scope
arv = c.open('/v2/auth',
headers=[('authorization', test_spec.gen_basic_auth(username, password))],
query_string=query_string)
msg = 'Auth failed for %s %s: got %s, expected: 200' % (
test_spec.method_name, test_spec.index_name, arv.status_code)
self.assertEqual(arv.status_code, 200, msg)
headers = [('authorization', 'Bearer ' + json.loads(arv.data)['token'])]
rv = c.open(url, headers=headers, method=test_spec.method_name)
msg = '%s %s: got %s, expected: %s (auth: %s | headers %s)' % (test_spec.method_name,
test_spec.index_name, rv.status_code, expected_index_status, attrs['auth_username'],
len(headers))
self.assertEqual(rv.status_code, expected_index_status, msg)
return test
def __new__(cls, name, bases, attrs):
with app.test_request_context() as ctx:
specs = attrs['spec_func']()
for test_spec in specs:
test_name = '%s_%s_%s_%s_%s' % (test_spec.index_name, test_spec.method_name,
test_spec.repo_name, attrs['auth_username'] or 'anon',
attrs['result_attr'])
test_name = test_name.replace('/', '_').replace('-', '_')
test_name = 'test_' + test_name.lower().replace('v2.', 'v2_')
url = test_spec.get_url()
attrs[test_name] = _SpecTestBuilder._test_generator(url, test_spec, attrs)
return type(name, bases, attrs)
class TestAnonymousAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v2_index_specs
result_attr = 'anon_code'
auth_username = None
class TestNoAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v2_index_specs
result_attr = 'no_access_code'
auth_username = NO_ACCESS_USER
class TestReadAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v2_index_specs
result_attr = 'read_code'
auth_username = READ_ACCESS_USER
class TestCreatorAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v2_index_specs
result_attr = 'creator_code'
auth_username = CREATOR_ACCESS_USER
class TestAdminAccess(EndpointTestCase):
__metaclass__ = _SpecTestBuilder
spec_func = build_v2_index_specs
result_attr = 'admin_code'
auth_username = ADMIN_ACCESS_USER
if __name__ == '__main__':
unittest.main()

114
test/testconfig.py Normal file
View file

@ -0,0 +1,114 @@
import os
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile
from config import DefaultConfig
class FakeTransaction(object):
def __enter__(self):
return self
def __exit__(self, exc_type, value, traceback):
pass
TEST_DB_FILE = NamedTemporaryFile(delete=True)
class TestConfig(DefaultConfig):
TESTING = True
SECRET_KEY = "superdupersecret!!!1"
DATABASE_SECRET_KEY = 'anothercrazykey!'
BILLING_TYPE = 'FakeStripe'
TEST_DB_FILE = TEST_DB_FILE
DB_URI = os.environ.get('TEST_DATABASE_URI', 'sqlite:///{0}'.format(TEST_DB_FILE.name))
DB_CONNECTION_ARGS = {
'threadlocals': True,
'autorollback': True,
}
@staticmethod
def create_transaction(db):
return FakeTransaction()
DB_TRANSACTION_FACTORY = create_transaction
DISTRIBUTED_STORAGE_CONFIG = {'local_us': ['FakeStorage', {}], 'local_eu': ['FakeStorage', {}]}
DISTRIBUTED_STORAGE_PREFERENCE = ['local_us']
BUILDLOGS_MODULE_AND_CLASS = ('test.testlogs', 'testlogs.TestBuildLogs')
BUILDLOGS_OPTIONS = ['devtable', 'building', 'deadbeef-dead-beef-dead-beefdeadbeef', False]
USERFILES_LOCATION = 'local_us'
USERFILES_PATH= "userfiles/"
FEATURE_SUPER_USERS = True
FEATURE_BILLING = True
FEATURE_MAILING = True
SUPER_USERS = ['devtable']
LICENSE_USER_LIMIT = 500
LICENSE_EXPIRATION = datetime.now() + timedelta(weeks=520)
LICENSE_EXPIRATION_WARNING = datetime.now() + timedelta(weeks=520)
FEATURE_GITHUB_BUILD = True
FEATURE_BITTORRENT = True
FEATURE_ACI_CONVERSION = True
CLOUDWATCH_NAMESPACE = None
FEATURE_SECURITY_SCANNER = True
FEATURE_SECURITY_NOTIFICATIONS = True
SECURITY_SCANNER_ENDPOINT = 'http://fakesecurityscanner/'
SECURITY_SCANNER_API_VERSION = 'v1'
SECURITY_SCANNER_ENGINE_VERSION_TARGET = 1
SECURITY_SCANNER_API_TIMEOUT_SECONDS = 1
FEATURE_SIGNING = True
SIGNING_ENGINE = 'gpg2'
GPG2_PRIVATE_KEY_NAME = 'EEB32221'
GPG2_PRIVATE_KEY_FILENAME = 'test/data/signing-private.gpg'
GPG2_PUBLIC_KEY_FILENAME = 'test/data/signing-public.gpg'
INSTANCE_SERVICE_KEY_KID_LOCATION = 'test/data/test.kid'
INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem'
PROMETHEUS_AGGREGATOR_URL = None
GITHUB_LOGIN_CONFIG = {}
GOOGLE_LOGIN_CONFIG = {}
FEATURE_GITHUB_LOGIN = True
FEATURE_GOOGLE_LOGIN = True
TESTOIDC_LOGIN_CONFIG = {
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'bar',
'OIDC_SERVER': 'http://fakeoidc',
'DEBUGGING': True,
'LOGIN_BINDING_FIELD': 'sub',
}
RECAPTCHA_SITE_KEY = 'somekey'
RECAPTCHA_SECRET_KEY = 'somesecretkey'
FEATURE_APP_REGISTRY = True
FEATURE_TEAM_SYNCING = True
FEATURE_CHANGE_TAG_EXPIRATION = True
TAG_EXPIRATION_OPTIONS = ['0s', '1s', '1d', '1w', '2w', '4w']
DEFAULT_NAMESPACE_MAXIMUM_BUILD_COUNT = None
DATA_MODEL_CACHE_CONFIG = {
'engine': 'inmemory',
}
FEATURE_REPO_MIRROR = True
V3_UPGRADE_MODE = 'complete'

238
test/testlogs.py Normal file
View file

@ -0,0 +1,238 @@
import logging
import datetime
from random import SystemRandom
from functools import wraps, partial
from copy import deepcopy
from jinja2.utils import generate_lorem_ipsum
from data.buildlogs import RedisBuildLogs
logger = logging.getLogger(__name__)
random = SystemRandom()
get_sentence = partial(generate_lorem_ipsum, html=False, n=1, min=5, max=10)
def maybe_advance_script(is_get_status=False):
def inner_advance(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
advance_units = random.randint(1, 500)
logger.debug('Advancing script %s units', advance_units)
while advance_units > 0 and self.remaining_script:
units = self.remaining_script[0][0]
if advance_units > units:
advance_units -= units
self.advance_script(is_get_status)
else:
break
return func(self, *args, **kwargs)
return wrapper
return inner_advance
class TestBuildLogs(RedisBuildLogs):
COMMAND_TYPES = ['FROM', 'MAINTAINER', 'RUN', 'CMD', 'EXPOSE', 'ENV', 'ADD',
'ENTRYPOINT', 'VOLUME', 'USER', 'WORKDIR']
STATUS_TEMPLATE = {
'total_commands': None,
'current_command': None,
'push_completion': 0.0,
'pull_completion': 0.0,
}
def __init__(self, redis_config, namespace, repository, test_build_id, allow_delegate=True):
super(TestBuildLogs, self).__init__(redis_config)
self.namespace = namespace
self.repository = repository
self.test_build_id = test_build_id
self.allow_delegate = allow_delegate
self.remaining_script = self._generate_script()
logger.debug('Total script size: %s', len(self.remaining_script))
self._logs = []
self._status = {}
self._last_status = {}
def advance_script(self, is_get_status):
(_, log, status_wrapper) = self.remaining_script.pop(0)
if log is not None:
self._logs.append(log)
if status_wrapper is not None:
(phase, status) = status_wrapper
if not is_get_status:
from data import model
build_obj = model.build.get_repository_build(self.test_build_id)
build_obj.phase = phase
build_obj.save()
self._status = status
self._last_status = status
def _generate_script(self):
script = []
# generate the init phase
script.append(self._generate_phase(400, 'initializing'))
script.extend(self._generate_logs(random.randint(1, 3)))
# move to the building phase
script.append(self._generate_phase(400, 'building'))
total_commands = random.randint(5, 20)
for command_num in range(1, total_commands + 1):
command_weight = random.randint(50, 100)
script.append(self._generate_command(command_num, total_commands, command_weight))
# we want 0 logs some percent of the time
num_logs = max(0, random.randint(-50, 400))
script.extend(self._generate_logs(num_logs))
# move to the pushing phase
script.append(self._generate_phase(400, 'pushing'))
script.extend(self._generate_push_statuses(total_commands))
# move to the error or complete phase
if random.randint(0, 1) == 0:
script.append(self._generate_phase(400, 'complete'))
else:
script.append(self._generate_phase(400, 'error'))
script.append((1, {'message': 'Something bad happened! Oh noes!',
'type': self.ERROR}, None))
return script
def _generate_phase(self, start_weight, phase_name):
message = {
'message': phase_name,
'type': self.PHASE,
'datetime': str(datetime.datetime.now())
}
return (start_weight, message,
(phase_name, deepcopy(self.STATUS_TEMPLATE)))
def _generate_command(self, command_num, total_commands, command_weight):
sentence = get_sentence()
command = random.choice(self.COMMAND_TYPES)
if command == 'FROM':
sentence = random.choice(['ubuntu', 'lopter/raring-base',
'quay.io/devtable/simple',
'quay.io/buynlarge/orgrepo',
'stackbrew/ubuntu:precise'])
msg = {
'message': 'Step %s: %s %s' % (command_num, command, sentence),
'type': self.COMMAND,
'datetime': str(datetime.datetime.now())
}
status = deepcopy(self.STATUS_TEMPLATE)
status['total_commands'] = total_commands
status['current_command'] = command_num
return (command_weight, msg, ('building', status))
@staticmethod
def _generate_logs(count):
others = []
if random.randint(0, 10) <= 8:
premessage = {
'message': '\x1b[91m' + get_sentence(),
'data': {'datetime': str(datetime.datetime.now())}
}
postmessage = {
'message': '\x1b[0m',
'data': {'datetime': str(datetime.datetime.now())}
}
count = count - 2
others = [(1, premessage, None), (1, postmessage, None)]
def get_message():
return {
'message': get_sentence(),
'data': {'datetime': str(datetime.datetime.now())}
}
return others + [(1, get_message(), None) for _ in range(count)]
@staticmethod
def _compute_total_completion(statuses, total_images):
percentage_with_sizes = float(len(statuses.values()))/total_images
sent_bytes = sum([status[u'current'] for status in statuses.values()])
total_bytes = sum([status[u'total'] for status in statuses.values()])
return float(sent_bytes)/total_bytes*percentage_with_sizes
@staticmethod
def _generate_push_statuses(total_commands):
push_status_template = deepcopy(TestBuildLogs.STATUS_TEMPLATE)
push_status_template['current_command'] = total_commands
push_status_template['total_commands'] = total_commands
push_statuses = []
one_mb = 1 * 1024 * 1024
num_images = random.randint(2, 7)
sizes = [random.randint(one_mb, one_mb * 5) for _ in range(num_images)]
image_completion = {}
for image_num, image_size in enumerate(sizes):
image_id = 'image_id_%s' % image_num
image_completion[image_id] = {
'current': 0,
'total': image_size,
}
for i in range(one_mb, image_size, one_mb):
image_completion[image_id]['current'] = i
new_status = deepcopy(push_status_template)
completion = TestBuildLogs._compute_total_completion(image_completion,
num_images)
new_status['push_completion'] = completion
push_statuses.append((250, None, ('pushing', new_status)))
return push_statuses
@maybe_advance_script()
def get_log_entries(self, build_id, start_index):
if build_id == self.test_build_id:
return (len(self._logs), self._logs[start_index:])
elif not self.allow_delegate:
return None
else:
return super(TestBuildLogs, self).get_log_entries(build_id, start_index)
@maybe_advance_script(True)
def get_status(self, build_id):
if build_id == self.test_build_id:
returnable_status = self._last_status
self._last_status = self._status
return returnable_status
elif not self.allow_delegate:
return None
else:
return super(TestBuildLogs, self).get_status(build_id)
def expire_log_entries(self, build_id):
if build_id == self.test_build_id:
return
if not self.allow_delegate:
return None
else:
return super(TestBuildLogs, self).expire_log_entries(build_id)
def delete_log_entries(self, build_id):
if build_id == self.test_build_id:
return
if not self.allow_delegate:
return None
else:
return super(TestBuildLogs, self).delete_log_entries(build_id)