Add ability for specific geographic regions to be blocked from pulling images within a namespace
parent c71a43a06c
commit c3710a6a5e
20 changed files with 257 additions and 37 deletions
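Taken together, the changes below add a NamespaceGeoRestriction table, a check_region_blacklisted endpoint decorator, and country-code resolution in the IP resolver, so pulls coming from a restricted region are rejected with a 403. As a minimal illustrative sketch of how a restriction would be created (mirroring the test fixture added in this commit; the namespace 'acme' and country code 'DE' are made up, and this assumes it runs inside a configured Quay application context):

  # Illustrative sketch only: creates a geo restriction row for a hypothetical
  # namespace. Endpoints wrapped with @check_region_blacklisted consult this
  # table (via a cached lookup) on every pull.
  from data.database import NamespaceGeoRestriction, User

  def block_region_for_namespace(namespace_name, iso_country_code):
    """ Blocks pulls of the given namespace from the given ISO country code. """
    return NamespaceGeoRestriction.create(namespace=User.get(username=namespace_name),
                                          description='Blocked per policy',
                                          unstructured_json={},
                                          restricted_region_iso_code=iso_country_code)

  block_region_for_namespace('acme', 'DE')

Requests resolved to a blocked country then fail with a 403 (or the new BLOB_DOWNLOAD_GEO_BLOCKED error on the V2 blob endpoint).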
data/cache/cache_key.py
@@ -14,3 +14,8 @@ def for_catalog_page(auth_context_key, start_id, limit):
   """ Returns a cache key for a single page of a catalog lookup for an authed context. """
   params = (auth_context_key or '(anon)', start_id or 0, limit or 0)
   return CacheKey('catalog_page__%s_%s_%s' % params, '60s')
+
+
+def for_namespace_geo_restrictions(namespace_name):
+  """ Returns a cache key for the geo restrictions for a namespace """
+  return CacheKey('geo_restrictions__%s' % (namespace_name), '240s')

@@ -504,7 +504,8 @@ class User(BaseModel):
                                RepositoryNotification, OAuthAuthorizationCode,
                                RepositoryActionCount, TagManifestLabel,
                                TeamSync, RepositorySearchScore,
-                               DeletedNamespace} | appr_classes | v22_classes | transition_classes
+                               DeletedNamespace,
+                               NamespaceGeoRestriction} | appr_classes | v22_classes | transition_classes
     delete_instance_filtered(self, User, delete_nullable, skip_transitive_deletes)


@@ -525,6 +526,21 @@ class DeletedNamespace(BaseModel):
   queue_id = CharField(null=True, index=True)


+class NamespaceGeoRestriction(BaseModel):
+  namespace = QuayUserField(index=True, allows_robots=False)
+  added = DateTimeField(default=datetime.utcnow)
+  description = CharField()
+  unstructured_json = JSONField()
+  restricted_region_iso_code = CharField(index=True)
+
+  class Meta:
+    database = db
+    read_slaves = (read_slave,)
+    indexes = (
+      (('namespace', 'restricted_region_iso_code'), True),
+    )
+
+
 class UserPromptTypes(object):
   CONFIRM_USERNAME = 'confirm_username'
   ENTER_NAME = 'enter_name'

@@ -0,0 +1,46 @@
+"""Add NamespaceGeoRestriction table
+
+Revision ID: 54492a68a3cf
+Revises: c00a1f15968b
+Create Date: 2018-12-05 15:12:14.201116
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '54492a68a3cf'
+down_revision = 'c00a1f15968b'
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import mysql
+
+def upgrade(tables, tester):
+  # ### commands auto generated by Alembic - please adjust! ###
+  op.create_table('namespacegeorestriction',
+                  sa.Column('id', sa.Integer(), nullable=False),
+                  sa.Column('namespace_id', sa.Integer(), nullable=False),
+                  sa.Column('added', sa.DateTime(), nullable=False),
+                  sa.Column('description', sa.String(length=255), nullable=False),
+                  sa.Column('unstructured_json', sa.Text(), nullable=False),
+                  sa.Column('restricted_region_iso_code', sa.String(length=255), nullable=False),
+                  sa.ForeignKeyConstraint(['namespace_id'], ['user.id'], name=op.f('fk_namespacegeorestriction_namespace_id_user')),
+                  sa.PrimaryKeyConstraint('id', name=op.f('pk_namespacegeorestriction'))
+                  )
+  op.create_index('namespacegeorestriction_namespace_id', 'namespacegeorestriction', ['namespace_id'], unique=False)
+  op.create_index('namespacegeorestriction_namespace_id_restricted_region_iso_code', 'namespacegeorestriction', ['namespace_id', 'restricted_region_iso_code'], unique=True)
+  op.create_index('namespacegeorestriction_restricted_region_iso_code', 'namespacegeorestriction', ['restricted_region_iso_code'], unique=False)
+  # ### end Alembic commands ###
+
+  tester.populate_table('namespacegeorestriction', [
+    ('namespace_id', tester.TestDataType.Foreign('user')),
+    ('added', tester.TestDataType.DateTime),
+    ('description', tester.TestDataType.String),
+    ('unstructured_json', tester.TestDataType.JSON),
+    ('restricted_region_iso_code', tester.TestDataType.String),
+  ])
+
+
+def downgrade(tables, tester):
+  # ### commands auto generated by Alembic - please adjust! ###
+  op.drop_table('namespacegeorestriction')
+  # ### end Alembic commands ###

@@ -15,7 +15,7 @@ from data.database import (User, LoginService, FederatedLogin, RepositoryPermiss
                            UserRegion, ImageStorageLocation,
                            ServiceKeyApproval, OAuthApplication, RepositoryBuildTrigger,
                            UserPromptKind, UserPrompt, UserPromptTypes, DeletedNamespace,
-                           RobotAccountMetadata)
+                           RobotAccountMetadata, NamespaceGeoRestriction)
 from data.model import (DataModelException, InvalidPasswordException, InvalidRobotException,
                         InvalidUsernameException, InvalidEmailAddressException,
                         TooManyLoginAttemptsException, db_transaction,

@@ -1060,6 +1060,14 @@ def get_federated_logins(user_ids, service_name):
                                              LoginService.name == service_name))


+def list_namespace_geo_restrictions(namespace_name):
+  """ Returns all of the defined geographic restrictions for the given namespace. """
+  return (NamespaceGeoRestriction
+          .select()
+          .join(User)
+          .where(User.username == namespace_name))
+
+
 class LoginWrappedDBUser(UserMixin):
   def __init__(self, user_uuid, db_user=None):
     self._uuid = user_uuid

@@ -316,3 +316,9 @@ class RegistryDataInterface(object):
     """ Creates a manifest under the repository and sets a temporary tag to point to it.
         Returns the manifest object created or None on error.
     """
+
+  @abstractmethod
+  def get_cached_namespace_region_blacklist(self, model_cache, namespace_name):
+    """ Returns a cached set of ISO country codes blacklisted for pulls for the namespace
+        or None if the list could not be loaded.
+    """

@@ -121,6 +121,27 @@ class SharedModel:
     torrent_info = model.storage.save_torrent_info(image_storage, piece_length, pieces)
     return TorrentInfo.for_torrent_info(torrent_info)

+  def get_cached_namespace_region_blacklist(self, model_cache, namespace_name):
+    """ Returns a cached set of ISO country codes blacklisted for pulls for the namespace
+        or None if the list could not be loaded.
+    """
+
+    def load_blacklist():
+      restrictions = model.user.list_namespace_geo_restrictions(namespace_name)
+      if restrictions is None:
+        return None
+
+      return [restriction.restricted_region_iso_code for restriction in restrictions]
+
+    blacklist_cache_key = cache_key.for_namespace_geo_restrictions(namespace_name)
+    result = model_cache.retrieve(blacklist_cache_key, load_blacklist)
+    if result is None:
+      return None
+
+    return set(result)
+
   def get_cached_repo_blob(self, model_cache, namespace_name, repo_name, blob_digest):
     """
     Returns the blob in the repository with the given digest if any or None if none.

@@ -18,7 +18,7 @@ from endpoints.appr import appr_bp, require_app_repo_read, require_app_repo_writ
 from endpoints.appr.cnr_backend import Blob, Channel, Package, User
 from endpoints.appr.decorators import disallow_for_image_repository
 from endpoints.appr.models_cnr import model
-from endpoints.decorators import anon_allowed, anon_protect
+from endpoints.decorators import anon_allowed, anon_protect, check_region_blacklisted
 from util.names import REPOSITORY_NAME_REGEX, TAG_REGEX

 logger = logging.getLogger(__name__)

@@ -71,6 +71,7 @@ def login():
              strict_slashes=False,)
 @process_auth
 @require_app_repo_read
+@check_region_blacklisted(namespace_name_kwarg='namespace')
 @anon_protect
 def blobs(namespace, package_name, digest):
   reponame = repo_name(namespace, package_name)

@@ -114,6 +115,7 @@ def delete_package(namespace, package_name, release, media_type):
                methods=['GET'], strict_slashes=False)
 @process_auth
 @require_app_repo_read
+@check_region_blacklisted(namespace_name_kwarg='namespace')
 @anon_protect
 def show_package(namespace, package_name, release, media_type):
   reponame = repo_name(namespace, package_name)

@@ -152,6 +154,7 @@ def show_package_release_manifests(namespace, package_name, release):
              strict_slashes=False,)
 @process_auth
 @require_app_repo_read
+@check_region_blacklisted(namespace_name_kwarg='namespace')
 @anon_protect
 def pull(namespace, package_name, release, media_type):
   logger.debug('Pull of release %s of app repository %s/%s', release, namespace, package_name)

@@ -1,5 +1,6 @@
 """ Various decorators for endpoint and API handlers. """

+import os
 import logging

 from functools import wraps

@@ -7,8 +8,9 @@ from flask import abort, request, make_response

 import features

-from app import app
+from app import app, ip_resolver, model_cache
 from auth.auth_context import get_authenticated_context
+from data.registry_model import registry_model
 from util.names import parse_namespace_repository, ImplicitLibraryNamespaceNotAllowed
 from util.http import abort

@@ -122,3 +124,40 @@ def require_xhr_from_browser(func):

     return func(*args, **kwargs)
   return wrapper
+
+
+def check_region_blacklisted(error_class=None, namespace_name_kwarg=None):
+  """ Decorator which checks if the incoming request is from a region geo IP blocked
+      for the current namespace. The first argument to the wrapped function must be
+      the namespace name.
+  """
+  def wrapper(wrapped):
+    @wraps(wrapped)
+    def decorated(*args, **kwargs):
+      if namespace_name_kwarg:
+        namespace_name = kwargs[namespace_name_kwarg]
+      else:
+        namespace_name = args[0]
+
+      region_blacklist = registry_model.get_cached_namespace_region_blacklist(model_cache,
+                                                                              namespace_name)
+      if region_blacklist:
+        # Resolve the IP information and block if on the namespace's blacklist.
+        remote_addr = request.remote_addr
+        if os.getenv('TEST', 'false').lower() == 'true':
+          remote_addr = request.headers.get('X-Override-Remote-Addr-For-Testing', remote_addr)
+
+        resolved_ip_info = ip_resolver.resolve_ip(remote_addr)
+        logger.debug('Resolved IP information for IP %s: %s', remote_addr, resolved_ip_info)
+
+        if (resolved_ip_info and
+            resolved_ip_info.country_iso_code and
+            resolved_ip_info.country_iso_code in region_blacklist):
+          if error_class:
+            raise error_class()
+
+          abort(403, 'Pulls of this data have been restricted geographically')
+
+      return wrapped(*args, **kwargs)
+    return decorated
+  return wrapper

@@ -18,7 +18,7 @@ from data.registry_model.manifestbuilder import lookup_manifest_builder
 from digest import checksums
 from endpoints.v1 import v1_bp
 from endpoints.v1.index import ensure_namespace_enabled
-from endpoints.decorators import anon_protect
+from endpoints.decorators import anon_protect, check_region_blacklisted
 from util.http import abort, exact_abort
 from util.registry.replication import queue_storage_replication

@@ -109,6 +109,7 @@ def head_image_layer(namespace, repository, image_id, headers):
 @ensure_namespace_enabled
 @require_completion
 @set_cache_headers
+@check_region_blacklisted()
 @anon_protect
 def get_image_layer(namespace, repository, image_id, headers):
   permission = ReadRepositoryPermission(namespace, repository)

@@ -13,11 +13,11 @@ from data.registry_model.blobuploader import (create_blob_upload, retrieve_blob_
                                               BlobUploadException, BlobTooLargeException,
                                               BlobRangeMismatchException)
 from digest import digest_tools
-from endpoints.decorators import anon_protect, parse_repository_name
+from endpoints.decorators import anon_protect, parse_repository_name, check_region_blacklisted
 from endpoints.v2 import v2_bp, require_repo_read, require_repo_write, get_input_stream
 from endpoints.v2.errors import (
     BlobUnknown, BlobUploadInvalid, BlobUploadUnknown, Unsupported, NameUnknown, LayerTooLarge,
-    InvalidRequest)
+    InvalidRequest, BlobDownloadGeoBlocked)
 from util.cache import cache_control
 from util.names import parse_namespace_repository

@@ -65,6 +65,7 @@ def check_blob_exists(namespace_name, repo_name, digest):
 @process_registry_jwt_auth(scopes=['pull'])
 @require_repo_read
 @anon_protect
+@check_region_blacklisted(BlobDownloadGeoBlocked)
 @cache_control(max_age=31536000)
 def download_blob(namespace_name, repo_name, digest):
   # Find the blob.

@@ -144,3 +144,10 @@ class NamespaceDisabled(V2RegistryException):
   def __init__(self, message=None):
     message = message or 'This namespace is disabled. Please contact your system administrator.'
     super(NamespaceDisabled, self).__init__('NAMESPACE_DISABLED', message, {}, 400)
+
+
+class BlobDownloadGeoBlocked(V2RegistryException):
+  def __init__(self, detail=None):
+    message = ('The region from which you are pulling has been geo-ip blocked. ' +
+               'Please contact the namespace owner.')
+    super(BlobDownloadGeoBlocked, self).__init__('BLOB_DOWNLOAD_GEO_BLOCKED', message, detail, 403)

@@ -13,7 +13,8 @@ from auth.permissions import ReadRepositoryPermission
 from data import database
 from data import model
 from data.registry_model import registry_model
-from endpoints.decorators import anon_protect, anon_allowed, route_show_if, parse_repository_name
+from endpoints.decorators import (anon_protect, anon_allowed, route_show_if, parse_repository_name,
+                                  check_region_blacklisted)
 from endpoints.v2.blob import BLOB_DIGEST_ROUTE
 from image.appc import AppCImageFormatter
 from image.docker import ManifestException

@@ -273,6 +274,7 @@ def _repo_verb_signature(namespace, repository, tag_name, verb, checker=None, **
   return make_response(signature_value)


+@check_region_blacklisted()
 def _repo_verb(namespace, repository, tag_name, verb, formatter, sign=False, checker=None,
                **kwargs):
   # Verify that the image exists and that we have access to it.

@@ -444,6 +446,7 @@ def get_squashed_tag(namespace, repository, tag):
 @verbs.route('/torrent{0}'.format(BLOB_DIGEST_ROUTE), methods=['GET'])
 @process_auth
 @parse_repository_name()
+@check_region_blacklisted(namespace_name_kwarg='namespace_name')
 def get_tag_torrent(namespace_name, repo_name, digest):
   repo = model.repository.get_repository(namespace_name, repo_name)
   repo_is_public = repo is not None and model.repository.is_repository_public(repo)

@@ -920,7 +920,8 @@ def populate_database(minimal=False, with_storage=False):
   model.repositoryactioncount.update_repository_score(to_count)


-WHITELISTED_EMPTY_MODELS = ['DeletedNamespace', 'LogEntry2', 'ManifestChild']
+WHITELISTED_EMPTY_MODELS = ['DeletedNamespace', 'LogEntry2', 'ManifestChild',
+                            'NamespaceGeoRestriction']

 def find_models_missing_data():
   # As a sanity check we are going to make sure that all db tables have some data, unless explicitly

@@ -15,7 +15,7 @@ from flask_principal import Identity
 from app import storage
 from data.database import (close_db_filter, configure, DerivedStorageForImage, QueueItem, Image,
                            TagManifest, TagManifestToManifest, Manifest, ManifestLegacyImage,
-                           ManifestBlob)
+                           ManifestBlob, NamespaceGeoRestriction, User)
 from data import model
 from data.registry_model import registry_model
 from endpoints.csrf import generate_csrf_token

@@ -116,6 +116,13 @@ def registry_server_executor(app):
     TagManifest.delete().execute()
     return 'OK'

+  def set_geo_block_for_namespace(namespace_name, iso_country_code):
+    NamespaceGeoRestriction.create(namespace=User.get(username=namespace_name),
+                                   description='',
+                                   unstructured_json={},
+                                   restricted_region_iso_code=iso_country_code)
+    return 'OK'
+
   executor = LiveServerExecutor()
   executor.register('generate_csrf', generate_csrf)
   executor.register('set_supports_direct_download', set_supports_direct_download)

@@ -130,6 +137,7 @@ def registry_server_executor(app):
   executor.register('create_app_repository', create_app_repository)
   executor.register('disable_namespace', disable_namespace)
   executor.register('delete_manifests', delete_manifests)
+  executor.register('set_geo_block_for_namespace', set_geo_block_for_namespace)
   return executor

@@ -15,6 +15,7 @@ class V1ProtocolSteps(Enum):
   PUT_TAG = 'put-tag'
   PUT_IMAGE_JSON = 'put-image-json'
   DELETE_TAG = 'delete-tag'
+  GET_LAYER = 'get-layer'


 class V1Protocol(RegistryProtocol):

@@ -45,6 +46,9 @@ class V1Protocol(RegistryProtocol):
       Failures.INVALID_IMAGES: 400,
       Failures.NAMESPACE_DISABLED: 400,
     },
+    V1ProtocolSteps.GET_LAYER: {
+      Failures.GEO_BLOCKED: 403,
+    },
   }

   def __init__(self, jwk):

@@ -118,7 +122,10 @@ class V1Protocol(RegistryProtocol):
       self.conduct(session, 'HEAD', image_prefix + 'layer', headers=headers)

       # And retrieve the layer data.
-      result = self.conduct(session, 'GET', image_prefix + 'layer', headers=headers)
+      result = self.conduct(session, 'GET', image_prefix + 'layer', headers=headers,
+                            expected_status=(200, expected_failure, V1ProtocolSteps.GET_LAYER),
+                            options=options)
+      if result.status_code == 200:
         assert result.content == images[index].bytes

     return PullResult(manifests=None, image_ids=image_ids)

@@ -27,6 +27,7 @@ class V2ProtocolSteps(Enum):
   CATALOG = 'catalog'
   LIST_TAGS = 'list-tags'
   START_UPLOAD = 'start-upload'
+  GET_BLOB = 'get-blob'


 class V2Protocol(RegistryProtocol):

@@ -48,6 +49,9 @@ class V2Protocol(RegistryProtocol):
       Failures.UNAUTHORIZED: 401,
       Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
     },
+    V2ProtocolSteps.GET_BLOB: {
+      Failures.GEO_BLOCKED: 403,
+    },
     V2ProtocolSteps.BLOB_HEAD_CHECK: {
       Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
     },

@@ -466,6 +470,7 @@ class V2Protocol(RegistryProtocol):
       assert response.headers['Content-Length'] == str(len(blob_bytes))

       # And retrieve the blob data.
+      if not options.skip_blob_push_checks:
         result = self.conduct(session, 'GET',
                               '/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name), blob_digest),
                               headers=headers, expected_status=200)

@@ -558,8 +563,10 @@ class V2Protocol(RegistryProtocol):
     result = self.conduct(session, 'GET',
                           '/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name),
                                                blob_digest),
-                          expected_status=expected_status,
-                          headers=headers)
+                          expected_status=(expected_status, expected_failure,
+                                           V2ProtocolSteps.GET_BLOB),
+                          headers=headers,
+                          options=options)

     if expected_status == 200:
       assert result.content == image.bytes

@@ -65,6 +65,7 @@ class Failures(Enum):
   INVALID_BLOB = 'invalid-blob'
   NAMESPACE_DISABLED = 'namespace-disabled'
   UNAUTHORIZED_FOR_MOUNT = 'unauthorized-for-mount'
+  GEO_BLOCKED = 'geo-blocked'


 class ProtocolOptions(object):

@@ -78,6 +79,8 @@ class ProtocolOptions(object):
     self.accept_mimetypes = None
     self.mount_blobs = None
     self.push_by_manifest_digest = False
+    self.request_addr = None
+    self.skip_blob_push_checks = False


 @add_metaclass(ABCMeta)

@@ -115,12 +118,16 @@ class RegistryProtocol(object):
     return repo_name

   def conduct(self, session, method, url, expected_status=200, params=None, data=None,
-              json_data=None, headers=None, auth=None):
+              json_data=None, headers=None, auth=None, options=None):
     if json_data is not None:
       data = json.dumps(json_data)
       headers = headers or {}
       headers['Content-Type'] = 'application/json'

+    if options and options.request_addr:
+      headers = headers or {}
+      headers['X-Override-Remote-Addr-For-Testing'] = options.request_addr
+
     if isinstance(expected_status, tuple):
       expected_status, expected_failure, protocol_step = expected_status
       if expected_failure is not None:

@@ -1706,3 +1706,25 @@ def test_verify_schema2(v22_protocol, basic_images, liveserver_session, liveserv
                           credentials=credentials)
   manifest = result.manifests['latest']
   assert manifest.schema_version == 2
+
+
+def test_geo_blocking(pusher, puller, basic_images, liveserver_session,
+                      liveserver, registry_server_executor, app_reloader):
+  """ Test: Attempt to pull an image from a geoblocked IP address. """
+  credentials = ('devtable', 'password')
+  options = ProtocolOptions()
+  options.skip_blob_push_checks = True  # Otherwise, cache gets established.
+
+  # Push a new repository.
+  pusher.push(liveserver_session, 'devtable', 'newrepo', 'latest', basic_images,
+              credentials=credentials, options=options)
+
+  registry_server_executor.on(liveserver).set_geo_block_for_namespace('devtable', 'US')
+
+  # Attempt to pull the repository to verify. This should fail with a 403 due to
+  # the geoblocking of the IP being used.
+  options = ProtocolOptions()
+  options.request_addr = '6.0.0.0'
+  puller.pull(liveserver_session, 'devtable', 'newrepo', 'latest', basic_images,
+              credentials=credentials, options=options,
+              expected_failure=Failures.GEO_BLOCKED)

@@ -16,7 +16,8 @@ import requests

 from util.abchelpers import nooper

-ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token'])
+ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token',
+                                                   'country_iso_code'])

 logger = logging.getLogger(__name__)

@@ -124,7 +125,11 @@ class IPResolver(IPResolverInterface):
        fails, returns None.
     """
     location_function = self._get_location_function()
-    if not ip_address or not location_function:
+    if not ip_address:
       return None

+    if not location_function:
+      logger.debug('No location function could be defined for IP address resolution')
+      return None
+
     return location_function(ip_address)

@@ -142,8 +147,9 @@ class IPResolver(IPResolverInterface):
     cache = CACHE
     sync_token = cache.get('sync_token', None)
     if sync_token is None:
-      logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json'])
-      return None
+      logger.debug('The aws ip range has not been cached from %s',
+                   _DATA_FILES['aws-ip-ranges.json'])
+      return IPResolver._build_location_function(sync_token, set(), {}, self.geoip_db)

     all_amazon = cache['all_amazon']
     regions = cache['regions']

@@ -163,20 +169,25 @@ class IPResolver(IPResolverInterface):
       try:
         parsed_ip = IPAddress(ip_address)
       except AddrFormatError:
-        return ResolvedLocation('invalid_ip', None, None, sync_token)
+        return ResolvedLocation('invalid_ip', None, None, sync_token, None)

-      if parsed_ip not in all_amazon:
-        # Try geoip classification
-        try:
-          found = country_db.country(parsed_ip)
+      # Try geoip classification
+      try:
+        geoinfo = country_db.country(parsed_ip)
+      except geoip2.errors.AddressNotFoundError:
+        geoinfo = None
+
+      if parsed_ip not in all_amazon:
+        if geoinfo:
           return ResolvedLocation(
             'internet',
-            found.continent.code,
-            found.country.iso_code,
+            geoinfo.continent.code,
+            geoinfo.country.iso_code,
             sync_token,
+            geoinfo.country.iso_code,
           )
-        except geoip2.errors.AddressNotFoundError:
-          return ResolvedLocation('internet', None, None, sync_token)
+
+        return ResolvedLocation('internet', None, None, sync_token, None)

       region = None

@@ -185,7 +196,8 @@ class IPResolver(IPResolverInterface):
           region = region_name
           break

-      return ResolvedLocation('aws', region, None, sync_token)
+      return ResolvedLocation('aws', region, None, sync_token,
+                              geoinfo.country.country_iso_code if geoinfo else None)
     return _get_location

   @staticmethod

@@ -44,17 +44,17 @@ def test_unstarted(app, test_aws_ip, unstarted_cache):
   ipresolver = IPResolver(app)

   with patch.dict('util.ipresolver.CACHE', unstarted_cache):
-    assert ipresolver.resolve_ip(test_aws_ip) is None
+    assert ipresolver.resolve_ip(test_aws_ip) is not None

 def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app):
   with patch('util.ipresolver._UpdateIPRange'):
     ipresolver = IPResolver(app)

   with patch.dict('util.ipresolver.CACHE', test_ip_range_cache):
-    assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
-    assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
-    assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
-    assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)
+    assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789, country_iso_code=None)
+    assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789, country_iso_code=None)
+    assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789, country_iso_code=u'US')
+    assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789, country_iso_code=None)

 def test_thread_missing_file():
   class LoopInterruptionForTest(Exception):