Add ability for specific geographic regions to be blocked from pulling images within a namespace

This commit is contained in:
Joseph Schorr 2018-12-05 15:19:37 -05:00
parent c71a43a06c
commit c3710a6a5e
20 changed files with 257 additions and 37 deletions

View file

@ -14,3 +14,8 @@ def for_catalog_page(auth_context_key, start_id, limit):
""" Returns a cache key for a single page of a catalog lookup for an authed context. """ """ Returns a cache key for a single page of a catalog lookup for an authed context. """
params = (auth_context_key or '(anon)', start_id or 0, limit or 0) params = (auth_context_key or '(anon)', start_id or 0, limit or 0)
return CacheKey('catalog_page__%s_%s_%s' % params, '60s') return CacheKey('catalog_page__%s_%s_%s' % params, '60s')
def for_namespace_geo_restrictions(namespace_name):
""" Returns a cache key for the geo restrictions for a namespace """
return CacheKey('geo_restrictions__%s' % (namespace_name), '240s')

View file

@ -504,7 +504,8 @@ class User(BaseModel):
RepositoryNotification, OAuthAuthorizationCode, RepositoryNotification, OAuthAuthorizationCode,
RepositoryActionCount, TagManifestLabel, RepositoryActionCount, TagManifestLabel,
TeamSync, RepositorySearchScore, TeamSync, RepositorySearchScore,
DeletedNamespace} | appr_classes | v22_classes | transition_classes DeletedNamespace,
NamespaceGeoRestriction} | appr_classes | v22_classes | transition_classes
delete_instance_filtered(self, User, delete_nullable, skip_transitive_deletes) delete_instance_filtered(self, User, delete_nullable, skip_transitive_deletes)
@ -525,6 +526,21 @@ class DeletedNamespace(BaseModel):
queue_id = CharField(null=True, index=True) queue_id = CharField(null=True, index=True)
class NamespaceGeoRestriction(BaseModel):
namespace = QuayUserField(index=True, allows_robots=False)
added = DateTimeField(default=datetime.utcnow)
description = CharField()
unstructured_json = JSONField()
restricted_region_iso_code = CharField(index=True)
class Meta:
database = db
read_slaves = (read_slave,)
indexes = (
(('namespace', 'restricted_region_iso_code'), True),
)
class UserPromptTypes(object): class UserPromptTypes(object):
CONFIRM_USERNAME = 'confirm_username' CONFIRM_USERNAME = 'confirm_username'
ENTER_NAME = 'enter_name' ENTER_NAME = 'enter_name'

View file

@ -0,0 +1,46 @@
"""Add NamespaceGeoRestriction table
Revision ID: 54492a68a3cf
Revises: c00a1f15968b
Create Date: 2018-12-05 15:12:14.201116
"""
# revision identifiers, used by Alembic.
revision = '54492a68a3cf'
down_revision = 'c00a1f15968b'
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
def upgrade(tables, tester):
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('namespacegeorestriction',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('namespace_id', sa.Integer(), nullable=False),
sa.Column('added', sa.DateTime(), nullable=False),
sa.Column('description', sa.String(length=255), nullable=False),
sa.Column('unstructured_json', sa.Text(), nullable=False),
sa.Column('restricted_region_iso_code', sa.String(length=255), nullable=False),
sa.ForeignKeyConstraint(['namespace_id'], ['user.id'], name=op.f('fk_namespacegeorestriction_namespace_id_user')),
sa.PrimaryKeyConstraint('id', name=op.f('pk_namespacegeorestriction'))
)
op.create_index('namespacegeorestriction_namespace_id', 'namespacegeorestriction', ['namespace_id'], unique=False)
op.create_index('namespacegeorestriction_namespace_id_restricted_region_iso_code', 'namespacegeorestriction', ['namespace_id', 'restricted_region_iso_code'], unique=True)
op.create_index('namespacegeorestriction_restricted_region_iso_code', 'namespacegeorestriction', ['restricted_region_iso_code'], unique=False)
# ### end Alembic commands ###
tester.populate_table('namespacegeorestriction', [
('namespace_id', tester.TestDataType.Foreign('user')),
('added', tester.TestDataType.DateTime),
('description', tester.TestDataType.String),
('unstructured_json', tester.TestDataType.JSON),
('restricted_region_iso_code', tester.TestDataType.String),
])
def downgrade(tables, tester):
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('namespacegeorestriction')
# ### end Alembic commands ###

View file

@ -15,7 +15,7 @@ from data.database import (User, LoginService, FederatedLogin, RepositoryPermiss
UserRegion, ImageStorageLocation, UserRegion, ImageStorageLocation,
ServiceKeyApproval, OAuthApplication, RepositoryBuildTrigger, ServiceKeyApproval, OAuthApplication, RepositoryBuildTrigger,
UserPromptKind, UserPrompt, UserPromptTypes, DeletedNamespace, UserPromptKind, UserPrompt, UserPromptTypes, DeletedNamespace,
RobotAccountMetadata) RobotAccountMetadata, NamespaceGeoRestriction)
from data.model import (DataModelException, InvalidPasswordException, InvalidRobotException, from data.model import (DataModelException, InvalidPasswordException, InvalidRobotException,
InvalidUsernameException, InvalidEmailAddressException, InvalidUsernameException, InvalidEmailAddressException,
TooManyLoginAttemptsException, db_transaction, TooManyLoginAttemptsException, db_transaction,
@ -1060,6 +1060,14 @@ def get_federated_logins(user_ids, service_name):
LoginService.name == service_name)) LoginService.name == service_name))
def list_namespace_geo_restrictions(namespace_name):
""" Returns all of the defined geographic restrictions for the given namespace. """
return (NamespaceGeoRestriction
.select()
.join(User)
.where(User.username == namespace_name))
class LoginWrappedDBUser(UserMixin): class LoginWrappedDBUser(UserMixin):
def __init__(self, user_uuid, db_user=None): def __init__(self, user_uuid, db_user=None):
self._uuid = user_uuid self._uuid = user_uuid

View file

@ -316,3 +316,9 @@ class RegistryDataInterface(object):
""" Creates a manifest under the repository and sets a temporary tag to point to it. """ Creates a manifest under the repository and sets a temporary tag to point to it.
Returns the manifest object created or None on error. Returns the manifest object created or None on error.
""" """
@abstractmethod
def get_cached_namespace_region_blacklist(self, model_cache, namespace_name):
""" Returns a cached set of ISO country codes blacklisted for pulls for the namespace
or None if the list could not be loaded.
"""

View file

@ -121,6 +121,27 @@ class SharedModel:
torrent_info = model.storage.save_torrent_info(image_storage, piece_length, pieces) torrent_info = model.storage.save_torrent_info(image_storage, piece_length, pieces)
return TorrentInfo.for_torrent_info(torrent_info) return TorrentInfo.for_torrent_info(torrent_info)
def get_cached_namespace_region_blacklist(self, model_cache, namespace_name):
""" Returns a cached set of ISO country codes blacklisted for pulls for the namespace
or None if the list could not be loaded.
"""
def load_blacklist():
restrictions = model.user.list_namespace_geo_restrictions(namespace_name)
if restrictions is None:
return None
return [restriction.restricted_region_iso_code for restriction in restrictions]
blacklist_cache_key = cache_key.for_namespace_geo_restrictions(namespace_name)
result = model_cache.retrieve(blacklist_cache_key, load_blacklist)
if result is None:
return None
return set(result)
def get_cached_repo_blob(self, model_cache, namespace_name, repo_name, blob_digest): def get_cached_repo_blob(self, model_cache, namespace_name, repo_name, blob_digest):
""" """
Returns the blob in the repository with the given digest if any or None if none. Returns the blob in the repository with the given digest if any or None if none.

View file

@ -18,7 +18,7 @@ from endpoints.appr import appr_bp, require_app_repo_read, require_app_repo_writ
from endpoints.appr.cnr_backend import Blob, Channel, Package, User from endpoints.appr.cnr_backend import Blob, Channel, Package, User
from endpoints.appr.decorators import disallow_for_image_repository from endpoints.appr.decorators import disallow_for_image_repository
from endpoints.appr.models_cnr import model from endpoints.appr.models_cnr import model
from endpoints.decorators import anon_allowed, anon_protect from endpoints.decorators import anon_allowed, anon_protect, check_region_blacklisted
from util.names import REPOSITORY_NAME_REGEX, TAG_REGEX from util.names import REPOSITORY_NAME_REGEX, TAG_REGEX
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -71,6 +71,7 @@ def login():
strict_slashes=False,) strict_slashes=False,)
@process_auth @process_auth
@require_app_repo_read @require_app_repo_read
@check_region_blacklisted(namespace_name_kwarg='namespace')
@anon_protect @anon_protect
def blobs(namespace, package_name, digest): def blobs(namespace, package_name, digest):
reponame = repo_name(namespace, package_name) reponame = repo_name(namespace, package_name)
@ -114,6 +115,7 @@ def delete_package(namespace, package_name, release, media_type):
methods=['GET'], strict_slashes=False) methods=['GET'], strict_slashes=False)
@process_auth @process_auth
@require_app_repo_read @require_app_repo_read
@check_region_blacklisted(namespace_name_kwarg='namespace')
@anon_protect @anon_protect
def show_package(namespace, package_name, release, media_type): def show_package(namespace, package_name, release, media_type):
reponame = repo_name(namespace, package_name) reponame = repo_name(namespace, package_name)
@ -152,6 +154,7 @@ def show_package_release_manifests(namespace, package_name, release):
strict_slashes=False,) strict_slashes=False,)
@process_auth @process_auth
@require_app_repo_read @require_app_repo_read
@check_region_blacklisted(namespace_name_kwarg='namespace')
@anon_protect @anon_protect
def pull(namespace, package_name, release, media_type): def pull(namespace, package_name, release, media_type):
logger.debug('Pull of release %s of app repository %s/%s', release, namespace, package_name) logger.debug('Pull of release %s of app repository %s/%s', release, namespace, package_name)

View file

@ -1,5 +1,6 @@
""" Various decorators for endpoint and API handlers. """ """ Various decorators for endpoint and API handlers. """
import os
import logging import logging
from functools import wraps from functools import wraps
@ -7,8 +8,9 @@ from flask import abort, request, make_response
import features import features
from app import app from app import app, ip_resolver, model_cache
from auth.auth_context import get_authenticated_context from auth.auth_context import get_authenticated_context
from data.registry_model import registry_model
from util.names import parse_namespace_repository, ImplicitLibraryNamespaceNotAllowed from util.names import parse_namespace_repository, ImplicitLibraryNamespaceNotAllowed
from util.http import abort from util.http import abort
@ -122,3 +124,40 @@ def require_xhr_from_browser(func):
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
def check_region_blacklisted(error_class=None, namespace_name_kwarg=None):
""" Decorator which checks if the incoming request is from a region geo IP blocked
for the current namespace. The first argument to the wrapped function must be
the namespace name.
"""
def wrapper(wrapped):
@wraps(wrapped)
def decorated(*args, **kwargs):
if namespace_name_kwarg:
namespace_name = kwargs[namespace_name_kwarg]
else:
namespace_name = args[0]
region_blacklist = registry_model.get_cached_namespace_region_blacklist(model_cache,
namespace_name)
if region_blacklist:
# Resolve the IP information and block if on the namespace's blacklist.
remote_addr = request.remote_addr
if os.getenv('TEST', 'false').lower() == 'true':
remote_addr = request.headers.get('X-Override-Remote-Addr-For-Testing', remote_addr)
resolved_ip_info = ip_resolver.resolve_ip(remote_addr)
logger.debug('Resolved IP information for IP %s: %s', remote_addr, resolved_ip_info)
if (resolved_ip_info and
resolved_ip_info.country_iso_code and
resolved_ip_info.country_iso_code in region_blacklist):
if error_class:
raise error_class()
abort(403, 'Pulls of this data have been restricted geographically')
return wrapped(*args, **kwargs)
return decorated
return wrapper

View file

@ -18,7 +18,7 @@ from data.registry_model.manifestbuilder import lookup_manifest_builder
from digest import checksums from digest import checksums
from endpoints.v1 import v1_bp from endpoints.v1 import v1_bp
from endpoints.v1.index import ensure_namespace_enabled from endpoints.v1.index import ensure_namespace_enabled
from endpoints.decorators import anon_protect from endpoints.decorators import anon_protect, check_region_blacklisted
from util.http import abort, exact_abort from util.http import abort, exact_abort
from util.registry.replication import queue_storage_replication from util.registry.replication import queue_storage_replication
@ -109,6 +109,7 @@ def head_image_layer(namespace, repository, image_id, headers):
@ensure_namespace_enabled @ensure_namespace_enabled
@require_completion @require_completion
@set_cache_headers @set_cache_headers
@check_region_blacklisted()
@anon_protect @anon_protect
def get_image_layer(namespace, repository, image_id, headers): def get_image_layer(namespace, repository, image_id, headers):
permission = ReadRepositoryPermission(namespace, repository) permission = ReadRepositoryPermission(namespace, repository)

View file

@ -13,11 +13,11 @@ from data.registry_model.blobuploader import (create_blob_upload, retrieve_blob_
BlobUploadException, BlobTooLargeException, BlobUploadException, BlobTooLargeException,
BlobRangeMismatchException) BlobRangeMismatchException)
from digest import digest_tools from digest import digest_tools
from endpoints.decorators import anon_protect, parse_repository_name from endpoints.decorators import anon_protect, parse_repository_name, check_region_blacklisted
from endpoints.v2 import v2_bp, require_repo_read, require_repo_write, get_input_stream from endpoints.v2 import v2_bp, require_repo_read, require_repo_write, get_input_stream
from endpoints.v2.errors import ( from endpoints.v2.errors import (
BlobUnknown, BlobUploadInvalid, BlobUploadUnknown, Unsupported, NameUnknown, LayerTooLarge, BlobUnknown, BlobUploadInvalid, BlobUploadUnknown, Unsupported, NameUnknown, LayerTooLarge,
InvalidRequest) InvalidRequest, BlobDownloadGeoBlocked)
from util.cache import cache_control from util.cache import cache_control
from util.names import parse_namespace_repository from util.names import parse_namespace_repository
@ -65,6 +65,7 @@ def check_blob_exists(namespace_name, repo_name, digest):
@process_registry_jwt_auth(scopes=['pull']) @process_registry_jwt_auth(scopes=['pull'])
@require_repo_read @require_repo_read
@anon_protect @anon_protect
@check_region_blacklisted(BlobDownloadGeoBlocked)
@cache_control(max_age=31536000) @cache_control(max_age=31536000)
def download_blob(namespace_name, repo_name, digest): def download_blob(namespace_name, repo_name, digest):
# Find the blob. # Find the blob.

View file

@ -144,3 +144,10 @@ class NamespaceDisabled(V2RegistryException):
def __init__(self, message=None): def __init__(self, message=None):
message = message or 'This namespace is disabled. Please contact your system administrator.' message = message or 'This namespace is disabled. Please contact your system administrator.'
super(NamespaceDisabled, self).__init__('NAMESPACE_DISABLED', message, {}, 400) super(NamespaceDisabled, self).__init__('NAMESPACE_DISABLED', message, {}, 400)
class BlobDownloadGeoBlocked(V2RegistryException):
def __init__(self, detail=None):
message = ('The region from which you are pulling has been geo-ip blocked. ' +
'Please contact the namespace owner.')
super(BlobDownloadGeoBlocked, self).__init__('BLOB_DOWNLOAD_GEO_BLOCKED', message, detail, 403)

View file

@ -13,7 +13,8 @@ from auth.permissions import ReadRepositoryPermission
from data import database from data import database
from data import model from data import model
from data.registry_model import registry_model from data.registry_model import registry_model
from endpoints.decorators import anon_protect, anon_allowed, route_show_if, parse_repository_name from endpoints.decorators import (anon_protect, anon_allowed, route_show_if, parse_repository_name,
check_region_blacklisted)
from endpoints.v2.blob import BLOB_DIGEST_ROUTE from endpoints.v2.blob import BLOB_DIGEST_ROUTE
from image.appc import AppCImageFormatter from image.appc import AppCImageFormatter
from image.docker import ManifestException from image.docker import ManifestException
@ -273,6 +274,7 @@ def _repo_verb_signature(namespace, repository, tag_name, verb, checker=None, **
return make_response(signature_value) return make_response(signature_value)
@check_region_blacklisted()
def _repo_verb(namespace, repository, tag_name, verb, formatter, sign=False, checker=None, def _repo_verb(namespace, repository, tag_name, verb, formatter, sign=False, checker=None,
**kwargs): **kwargs):
# Verify that the image exists and that we have access to it. # Verify that the image exists and that we have access to it.
@ -444,6 +446,7 @@ def get_squashed_tag(namespace, repository, tag):
@verbs.route('/torrent{0}'.format(BLOB_DIGEST_ROUTE), methods=['GET']) @verbs.route('/torrent{0}'.format(BLOB_DIGEST_ROUTE), methods=['GET'])
@process_auth @process_auth
@parse_repository_name() @parse_repository_name()
@check_region_blacklisted(namespace_name_kwarg='namespace_name')
def get_tag_torrent(namespace_name, repo_name, digest): def get_tag_torrent(namespace_name, repo_name, digest):
repo = model.repository.get_repository(namespace_name, repo_name) repo = model.repository.get_repository(namespace_name, repo_name)
repo_is_public = repo is not None and model.repository.is_repository_public(repo) repo_is_public = repo is not None and model.repository.is_repository_public(repo)

View file

@ -920,7 +920,8 @@ def populate_database(minimal=False, with_storage=False):
model.repositoryactioncount.update_repository_score(to_count) model.repositoryactioncount.update_repository_score(to_count)
WHITELISTED_EMPTY_MODELS = ['DeletedNamespace', 'LogEntry2', 'ManifestChild'] WHITELISTED_EMPTY_MODELS = ['DeletedNamespace', 'LogEntry2', 'ManifestChild',
'NamespaceGeoRestriction']
def find_models_missing_data(): def find_models_missing_data():
# As a sanity check we are going to make sure that all db tables have some data, unless explicitly # As a sanity check we are going to make sure that all db tables have some data, unless explicitly

View file

@ -15,7 +15,7 @@ from flask_principal import Identity
from app import storage from app import storage
from data.database import (close_db_filter, configure, DerivedStorageForImage, QueueItem, Image, from data.database import (close_db_filter, configure, DerivedStorageForImage, QueueItem, Image,
TagManifest, TagManifestToManifest, Manifest, ManifestLegacyImage, TagManifest, TagManifestToManifest, Manifest, ManifestLegacyImage,
ManifestBlob) ManifestBlob, NamespaceGeoRestriction, User)
from data import model from data import model
from data.registry_model import registry_model from data.registry_model import registry_model
from endpoints.csrf import generate_csrf_token from endpoints.csrf import generate_csrf_token
@ -116,6 +116,13 @@ def registry_server_executor(app):
TagManifest.delete().execute() TagManifest.delete().execute()
return 'OK' return 'OK'
def set_geo_block_for_namespace(namespace_name, iso_country_code):
NamespaceGeoRestriction.create(namespace=User.get(username=namespace_name),
description='',
unstructured_json={},
restricted_region_iso_code=iso_country_code)
return 'OK'
executor = LiveServerExecutor() executor = LiveServerExecutor()
executor.register('generate_csrf', generate_csrf) executor.register('generate_csrf', generate_csrf)
executor.register('set_supports_direct_download', set_supports_direct_download) executor.register('set_supports_direct_download', set_supports_direct_download)
@ -130,6 +137,7 @@ def registry_server_executor(app):
executor.register('create_app_repository', create_app_repository) executor.register('create_app_repository', create_app_repository)
executor.register('disable_namespace', disable_namespace) executor.register('disable_namespace', disable_namespace)
executor.register('delete_manifests', delete_manifests) executor.register('delete_manifests', delete_manifests)
executor.register('set_geo_block_for_namespace', set_geo_block_for_namespace)
return executor return executor

View file

@ -15,6 +15,7 @@ class V1ProtocolSteps(Enum):
PUT_TAG = 'put-tag' PUT_TAG = 'put-tag'
PUT_IMAGE_JSON = 'put-image-json' PUT_IMAGE_JSON = 'put-image-json'
DELETE_TAG = 'delete-tag' DELETE_TAG = 'delete-tag'
GET_LAYER = 'get-layer'
class V1Protocol(RegistryProtocol): class V1Protocol(RegistryProtocol):
@ -45,6 +46,9 @@ class V1Protocol(RegistryProtocol):
Failures.INVALID_IMAGES: 400, Failures.INVALID_IMAGES: 400,
Failures.NAMESPACE_DISABLED: 400, Failures.NAMESPACE_DISABLED: 400,
}, },
V1ProtocolSteps.GET_LAYER: {
Failures.GEO_BLOCKED: 403,
},
} }
def __init__(self, jwk): def __init__(self, jwk):
@ -118,7 +122,10 @@ class V1Protocol(RegistryProtocol):
self.conduct(session, 'HEAD', image_prefix + 'layer', headers=headers) self.conduct(session, 'HEAD', image_prefix + 'layer', headers=headers)
# And retrieve the layer data. # And retrieve the layer data.
result = self.conduct(session, 'GET', image_prefix + 'layer', headers=headers) result = self.conduct(session, 'GET', image_prefix + 'layer', headers=headers,
expected_status=(200, expected_failure, V1ProtocolSteps.GET_LAYER),
options=options)
if result.status_code == 200:
assert result.content == images[index].bytes assert result.content == images[index].bytes
return PullResult(manifests=None, image_ids=image_ids) return PullResult(manifests=None, image_ids=image_ids)

View file

@ -27,6 +27,7 @@ class V2ProtocolSteps(Enum):
CATALOG = 'catalog' CATALOG = 'catalog'
LIST_TAGS = 'list-tags' LIST_TAGS = 'list-tags'
START_UPLOAD = 'start-upload' START_UPLOAD = 'start-upload'
GET_BLOB = 'get-blob'
class V2Protocol(RegistryProtocol): class V2Protocol(RegistryProtocol):
@ -48,6 +49,9 @@ class V2Protocol(RegistryProtocol):
Failures.UNAUTHORIZED: 401, Failures.UNAUTHORIZED: 401,
Failures.DISALLOWED_LIBRARY_NAMESPACE: 400, Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
}, },
V2ProtocolSteps.GET_BLOB: {
Failures.GEO_BLOCKED: 403,
},
V2ProtocolSteps.BLOB_HEAD_CHECK: { V2ProtocolSteps.BLOB_HEAD_CHECK: {
Failures.DISALLOWED_LIBRARY_NAMESPACE: 400, Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
}, },
@ -466,6 +470,7 @@ class V2Protocol(RegistryProtocol):
assert response.headers['Content-Length'] == str(len(blob_bytes)) assert response.headers['Content-Length'] == str(len(blob_bytes))
# And retrieve the blob data. # And retrieve the blob data.
if not options.skip_blob_push_checks:
result = self.conduct(session, 'GET', result = self.conduct(session, 'GET',
'/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name), blob_digest), '/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name), blob_digest),
headers=headers, expected_status=200) headers=headers, expected_status=200)
@ -558,8 +563,10 @@ class V2Protocol(RegistryProtocol):
result = self.conduct(session, 'GET', result = self.conduct(session, 'GET',
'/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name), '/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name),
blob_digest), blob_digest),
expected_status=expected_status, expected_status=(expected_status, expected_failure,
headers=headers) V2ProtocolSteps.GET_BLOB),
headers=headers,
options=options)
if expected_status == 200: if expected_status == 200:
assert result.content == image.bytes assert result.content == image.bytes

View file

@ -65,6 +65,7 @@ class Failures(Enum):
INVALID_BLOB = 'invalid-blob' INVALID_BLOB = 'invalid-blob'
NAMESPACE_DISABLED = 'namespace-disabled' NAMESPACE_DISABLED = 'namespace-disabled'
UNAUTHORIZED_FOR_MOUNT = 'unauthorized-for-mount' UNAUTHORIZED_FOR_MOUNT = 'unauthorized-for-mount'
GEO_BLOCKED = 'geo-blocked'
class ProtocolOptions(object): class ProtocolOptions(object):
@ -78,6 +79,8 @@ class ProtocolOptions(object):
self.accept_mimetypes = None self.accept_mimetypes = None
self.mount_blobs = None self.mount_blobs = None
self.push_by_manifest_digest = False self.push_by_manifest_digest = False
self.request_addr = None
self.skip_blob_push_checks = False
@add_metaclass(ABCMeta) @add_metaclass(ABCMeta)
@ -115,12 +118,16 @@ class RegistryProtocol(object):
return repo_name return repo_name
def conduct(self, session, method, url, expected_status=200, params=None, data=None, def conduct(self, session, method, url, expected_status=200, params=None, data=None,
json_data=None, headers=None, auth=None): json_data=None, headers=None, auth=None, options=None):
if json_data is not None: if json_data is not None:
data = json.dumps(json_data) data = json.dumps(json_data)
headers = headers or {} headers = headers or {}
headers['Content-Type'] = 'application/json' headers['Content-Type'] = 'application/json'
if options and options.request_addr:
headers = headers or {}
headers['X-Override-Remote-Addr-For-Testing'] = options.request_addr
if isinstance(expected_status, tuple): if isinstance(expected_status, tuple):
expected_status, expected_failure, protocol_step = expected_status expected_status, expected_failure, protocol_step = expected_status
if expected_failure is not None: if expected_failure is not None:

View file

@ -1706,3 +1706,25 @@ def test_verify_schema2(v22_protocol, basic_images, liveserver_session, liveserv
credentials=credentials) credentials=credentials)
manifest = result.manifests['latest'] manifest = result.manifests['latest']
assert manifest.schema_version == 2 assert manifest.schema_version == 2
def test_geo_blocking(pusher, puller, basic_images, liveserver_session,
liveserver, registry_server_executor, app_reloader):
""" Test: Attempt to pull an image from a geoblocked IP address. """
credentials = ('devtable', 'password')
options = ProtocolOptions()
options.skip_blob_push_checks = True # Otherwise, cache gets established.
# Push a new repository.
pusher.push(liveserver_session, 'devtable', 'newrepo', 'latest', basic_images,
credentials=credentials, options=options)
registry_server_executor.on(liveserver).set_geo_block_for_namespace('devtable', 'US')
# Attempt to pull the repository to verify. This should fail with a 403 due to
# the geoblocking of the IP being using.
options = ProtocolOptions()
options.request_addr = '6.0.0.0'
puller.pull(liveserver_session, 'devtable', 'newrepo', 'latest', basic_images,
credentials=credentials, options=options,
expected_failure=Failures.GEO_BLOCKED)

View file

@ -16,7 +16,8 @@ import requests
from util.abchelpers import nooper from util.abchelpers import nooper
ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token']) ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token',
'country_iso_code'])
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -124,7 +125,11 @@ class IPResolver(IPResolverInterface):
fails, returns None. fails, returns None.
""" """
location_function = self._get_location_function() location_function = self._get_location_function()
if not ip_address or not location_function: if not ip_address:
return None
if not location_function:
logger.debug('No location function could be defined for IP address resolution')
return None return None
return location_function(ip_address) return location_function(ip_address)
@ -142,8 +147,9 @@ class IPResolver(IPResolverInterface):
cache = CACHE cache = CACHE
sync_token = cache.get('sync_token', None) sync_token = cache.get('sync_token', None)
if sync_token is None: if sync_token is None:
logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json']) logger.debug('The aws ip range has not been cached from %s',
return None _DATA_FILES['aws-ip-ranges.json'])
return IPResolver._build_location_function(sync_token, set(), {}, self.geoip_db)
all_amazon = cache['all_amazon'] all_amazon = cache['all_amazon']
regions = cache['regions'] regions = cache['regions']
@ -163,20 +169,25 @@ class IPResolver(IPResolverInterface):
try: try:
parsed_ip = IPAddress(ip_address) parsed_ip = IPAddress(ip_address)
except AddrFormatError: except AddrFormatError:
return ResolvedLocation('invalid_ip', None, None, sync_token) return ResolvedLocation('invalid_ip', None, None, sync_token, None)
if parsed_ip not in all_amazon:
# Try geoip classification # Try geoip classification
try: try:
found = country_db.country(parsed_ip) geoinfo = country_db.country(parsed_ip)
except geoip2.errors.AddressNotFoundError:
geoinfo = None
if parsed_ip not in all_amazon:
if geoinfo:
return ResolvedLocation( return ResolvedLocation(
'internet', 'internet',
found.continent.code, geoinfo.continent.code,
found.country.iso_code, geoinfo.country.iso_code,
sync_token, sync_token,
geoinfo.country.iso_code,
) )
except geoip2.errors.AddressNotFoundError:
return ResolvedLocation('internet', None, None, sync_token) return ResolvedLocation('internet', None, None, sync_token, None)
region = None region = None
@ -185,7 +196,8 @@ class IPResolver(IPResolverInterface):
region = region_name region = region_name
break break
return ResolvedLocation('aws', region, None, sync_token) return ResolvedLocation('aws', region, None, sync_token,
geoinfo.country.country_iso_code if geoinfo else None)
return _get_location return _get_location
@staticmethod @staticmethod

View file

@ -44,17 +44,17 @@ def test_unstarted(app, test_aws_ip, unstarted_cache):
ipresolver = IPResolver(app) ipresolver = IPResolver(app)
with patch.dict('util.ipresolver.CACHE', unstarted_cache): with patch.dict('util.ipresolver.CACHE', unstarted_cache):
assert ipresolver.resolve_ip(test_aws_ip) is None assert ipresolver.resolve_ip(test_aws_ip) is not None
def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app): def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app):
with patch('util.ipresolver._UpdateIPRange'): with patch('util.ipresolver._UpdateIPRange'):
ipresolver = IPResolver(app) ipresolver = IPResolver(app)
with patch.dict('util.ipresolver.CACHE', test_ip_range_cache): with patch.dict('util.ipresolver.CACHE', test_ip_range_cache):
assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789, country_iso_code=None)
assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789, country_iso_code=None)
assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789) assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789, country_iso_code=u'US')
assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789) assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789, country_iso_code=None)
def test_thread_missing_file(): def test_thread_missing_file():
class LoopInterruptionForTest(Exception): class LoopInterruptionForTest(Exception):