diff --git a/data/cache/cache_key.py b/data/cache/cache_key.py index 924d72b12..39e4cac67 100644 --- a/data/cache/cache_key.py +++ b/data/cache/cache_key.py @@ -14,3 +14,8 @@ def for_catalog_page(auth_context_key, start_id, limit): """ Returns a cache key for a single page of a catalog lookup for an authed context. """ params = (auth_context_key or '(anon)', start_id or 0, limit or 0) return CacheKey('catalog_page__%s_%s_%s' % params, '60s') + + +def for_namespace_geo_restrictions(namespace_name): + """ Returns a cache key for the geo restrictions for a namespace """ + return CacheKey('geo_restrictions__%s' % (namespace_name), '240s') diff --git a/data/database.py b/data/database.py index 2649c048b..accb42a6e 100644 --- a/data/database.py +++ b/data/database.py @@ -504,7 +504,8 @@ class User(BaseModel): RepositoryNotification, OAuthAuthorizationCode, RepositoryActionCount, TagManifestLabel, TeamSync, RepositorySearchScore, - DeletedNamespace} | appr_classes | v22_classes | transition_classes + DeletedNamespace, + NamespaceGeoRestriction} | appr_classes | v22_classes | transition_classes delete_instance_filtered(self, User, delete_nullable, skip_transitive_deletes) @@ -525,6 +526,21 @@ class DeletedNamespace(BaseModel): queue_id = CharField(null=True, index=True) +class NamespaceGeoRestriction(BaseModel): + namespace = QuayUserField(index=True, allows_robots=False) + added = DateTimeField(default=datetime.utcnow) + description = CharField() + unstructured_json = JSONField() + restricted_region_iso_code = CharField(index=True) + + class Meta: + database = db + read_slaves = (read_slave,) + indexes = ( + (('namespace', 'restricted_region_iso_code'), True), + ) + + class UserPromptTypes(object): CONFIRM_USERNAME = 'confirm_username' ENTER_NAME = 'enter_name' diff --git a/data/migrations/versions/54492a68a3cf_add_namespacegeorestriction_table.py b/data/migrations/versions/54492a68a3cf_add_namespacegeorestriction_table.py new file mode 100644 index 000000000..38459b69d --- /dev/null +++ b/data/migrations/versions/54492a68a3cf_add_namespacegeorestriction_table.py @@ -0,0 +1,46 @@ +"""Add NamespaceGeoRestriction table + +Revision ID: 54492a68a3cf +Revises: c00a1f15968b +Create Date: 2018-12-05 15:12:14.201116 + +""" + +# revision identifiers, used by Alembic. +revision = '54492a68a3cf' +down_revision = 'c00a1f15968b' + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +def upgrade(tables, tester): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('namespacegeorestriction', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('namespace_id', sa.Integer(), nullable=False), + sa.Column('added', sa.DateTime(), nullable=False), + sa.Column('description', sa.String(length=255), nullable=False), + sa.Column('unstructured_json', sa.Text(), nullable=False), + sa.Column('restricted_region_iso_code', sa.String(length=255), nullable=False), + sa.ForeignKeyConstraint(['namespace_id'], ['user.id'], name=op.f('fk_namespacegeorestriction_namespace_id_user')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_namespacegeorestriction')) + ) + op.create_index('namespacegeorestriction_namespace_id', 'namespacegeorestriction', ['namespace_id'], unique=False) + op.create_index('namespacegeorestriction_namespace_id_restricted_region_iso_code', 'namespacegeorestriction', ['namespace_id', 'restricted_region_iso_code'], unique=True) + op.create_index('namespacegeorestriction_restricted_region_iso_code', 'namespacegeorestriction', ['restricted_region_iso_code'], unique=False) + # ### end Alembic commands ### + + tester.populate_table('namespacegeorestriction', [ + ('namespace_id', tester.TestDataType.Foreign('user')), + ('added', tester.TestDataType.DateTime), + ('description', tester.TestDataType.String), + ('unstructured_json', tester.TestDataType.JSON), + ('restricted_region_iso_code', tester.TestDataType.String), + ]) + + +def downgrade(tables, tester): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('namespacegeorestriction') + # ### end Alembic commands ### diff --git a/data/model/user.py b/data/model/user.py index d9d962a6f..8e79d3982 100644 --- a/data/model/user.py +++ b/data/model/user.py @@ -15,7 +15,7 @@ from data.database import (User, LoginService, FederatedLogin, RepositoryPermiss UserRegion, ImageStorageLocation, ServiceKeyApproval, OAuthApplication, RepositoryBuildTrigger, UserPromptKind, UserPrompt, UserPromptTypes, DeletedNamespace, - RobotAccountMetadata) + RobotAccountMetadata, NamespaceGeoRestriction) from data.model import (DataModelException, InvalidPasswordException, InvalidRobotException, InvalidUsernameException, InvalidEmailAddressException, TooManyLoginAttemptsException, db_transaction, @@ -1060,6 +1060,14 @@ def get_federated_logins(user_ids, service_name): LoginService.name == service_name)) +def list_namespace_geo_restrictions(namespace_name): + """ Returns all of the defined geographic restrictions for the given namespace. """ + return (NamespaceGeoRestriction + .select() + .join(User) + .where(User.username == namespace_name)) + + class LoginWrappedDBUser(UserMixin): def __init__(self, user_uuid, db_user=None): self._uuid = user_uuid diff --git a/data/registry_model/interface.py b/data/registry_model/interface.py index ce9f00e88..83f3394f7 100644 --- a/data/registry_model/interface.py +++ b/data/registry_model/interface.py @@ -316,3 +316,9 @@ class RegistryDataInterface(object): """ Creates a manifest under the repository and sets a temporary tag to point to it. Returns the manifest object created or None on error. """ + + @abstractmethod + def get_cached_namespace_region_blacklist(self, model_cache, namespace_name): + """ Returns a cached set of ISO country codes blacklisted for pulls for the namespace + or None if the list could not be loaded. + """ diff --git a/data/registry_model/shared.py b/data/registry_model/shared.py index 015b8df6d..8b62d86fd 100644 --- a/data/registry_model/shared.py +++ b/data/registry_model/shared.py @@ -121,6 +121,27 @@ class SharedModel: torrent_info = model.storage.save_torrent_info(image_storage, piece_length, pieces) return TorrentInfo.for_torrent_info(torrent_info) + + def get_cached_namespace_region_blacklist(self, model_cache, namespace_name): + """ Returns a cached set of ISO country codes blacklisted for pulls for the namespace + or None if the list could not be loaded. + """ + + def load_blacklist(): + restrictions = model.user.list_namespace_geo_restrictions(namespace_name) + if restrictions is None: + return None + + return [restriction.restricted_region_iso_code for restriction in restrictions] + + blacklist_cache_key = cache_key.for_namespace_geo_restrictions(namespace_name) + result = model_cache.retrieve(blacklist_cache_key, load_blacklist) + if result is None: + return None + + return set(result) + + def get_cached_repo_blob(self, model_cache, namespace_name, repo_name, blob_digest): """ Returns the blob in the repository with the given digest if any or None if none. diff --git a/endpoints/appr/registry.py b/endpoints/appr/registry.py index dae4f1799..b97ccb2d6 100644 --- a/endpoints/appr/registry.py +++ b/endpoints/appr/registry.py @@ -18,7 +18,7 @@ from endpoints.appr import appr_bp, require_app_repo_read, require_app_repo_writ from endpoints.appr.cnr_backend import Blob, Channel, Package, User from endpoints.appr.decorators import disallow_for_image_repository from endpoints.appr.models_cnr import model -from endpoints.decorators import anon_allowed, anon_protect +from endpoints.decorators import anon_allowed, anon_protect, check_region_blacklisted from util.names import REPOSITORY_NAME_REGEX, TAG_REGEX logger = logging.getLogger(__name__) @@ -71,6 +71,7 @@ def login(): strict_slashes=False,) @process_auth @require_app_repo_read +@check_region_blacklisted(namespace_name_kwarg='namespace') @anon_protect def blobs(namespace, package_name, digest): reponame = repo_name(namespace, package_name) @@ -114,6 +115,7 @@ def delete_package(namespace, package_name, release, media_type): methods=['GET'], strict_slashes=False) @process_auth @require_app_repo_read +@check_region_blacklisted(namespace_name_kwarg='namespace') @anon_protect def show_package(namespace, package_name, release, media_type): reponame = repo_name(namespace, package_name) @@ -152,6 +154,7 @@ def show_package_release_manifests(namespace, package_name, release): strict_slashes=False,) @process_auth @require_app_repo_read +@check_region_blacklisted(namespace_name_kwarg='namespace') @anon_protect def pull(namespace, package_name, release, media_type): logger.debug('Pull of release %s of app repository %s/%s', release, namespace, package_name) diff --git a/endpoints/decorators.py b/endpoints/decorators.py index bd16d1a8a..87ed393e5 100644 --- a/endpoints/decorators.py +++ b/endpoints/decorators.py @@ -1,5 +1,6 @@ """ Various decorators for endpoint and API handlers. """ +import os import logging from functools import wraps @@ -7,8 +8,9 @@ from flask import abort, request, make_response import features -from app import app +from app import app, ip_resolver, model_cache from auth.auth_context import get_authenticated_context +from data.registry_model import registry_model from util.names import parse_namespace_repository, ImplicitLibraryNamespaceNotAllowed from util.http import abort @@ -122,3 +124,40 @@ def require_xhr_from_browser(func): return func(*args, **kwargs) return wrapper + + +def check_region_blacklisted(error_class=None, namespace_name_kwarg=None): + """ Decorator which checks if the incoming request is from a region geo IP blocked + for the current namespace. The first argument to the wrapped function must be + the namespace name. + """ + def wrapper(wrapped): + @wraps(wrapped) + def decorated(*args, **kwargs): + if namespace_name_kwarg: + namespace_name = kwargs[namespace_name_kwarg] + else: + namespace_name = args[0] + + region_blacklist = registry_model.get_cached_namespace_region_blacklist(model_cache, + namespace_name) + if region_blacklist: + # Resolve the IP information and block if on the namespace's blacklist. + remote_addr = request.remote_addr + if os.getenv('TEST', 'false').lower() == 'true': + remote_addr = request.headers.get('X-Override-Remote-Addr-For-Testing', remote_addr) + + resolved_ip_info = ip_resolver.resolve_ip(remote_addr) + logger.debug('Resolved IP information for IP %s: %s', remote_addr, resolved_ip_info) + + if (resolved_ip_info and + resolved_ip_info.country_iso_code and + resolved_ip_info.country_iso_code in region_blacklist): + if error_class: + raise error_class() + + abort(403, 'Pulls of this data have been restricted geographically') + + return wrapped(*args, **kwargs) + return decorated + return wrapper diff --git a/endpoints/v1/registry.py b/endpoints/v1/registry.py index 08d99313d..ff5db6bf5 100644 --- a/endpoints/v1/registry.py +++ b/endpoints/v1/registry.py @@ -18,7 +18,7 @@ from data.registry_model.manifestbuilder import lookup_manifest_builder from digest import checksums from endpoints.v1 import v1_bp from endpoints.v1.index import ensure_namespace_enabled -from endpoints.decorators import anon_protect +from endpoints.decorators import anon_protect, check_region_blacklisted from util.http import abort, exact_abort from util.registry.replication import queue_storage_replication @@ -109,6 +109,7 @@ def head_image_layer(namespace, repository, image_id, headers): @ensure_namespace_enabled @require_completion @set_cache_headers +@check_region_blacklisted() @anon_protect def get_image_layer(namespace, repository, image_id, headers): permission = ReadRepositoryPermission(namespace, repository) diff --git a/endpoints/v2/blob.py b/endpoints/v2/blob.py index 0bfbe7b52..99f5ef0e8 100644 --- a/endpoints/v2/blob.py +++ b/endpoints/v2/blob.py @@ -13,11 +13,11 @@ from data.registry_model.blobuploader import (create_blob_upload, retrieve_blob_ BlobUploadException, BlobTooLargeException, BlobRangeMismatchException) from digest import digest_tools -from endpoints.decorators import anon_protect, parse_repository_name +from endpoints.decorators import anon_protect, parse_repository_name, check_region_blacklisted from endpoints.v2 import v2_bp, require_repo_read, require_repo_write, get_input_stream from endpoints.v2.errors import ( BlobUnknown, BlobUploadInvalid, BlobUploadUnknown, Unsupported, NameUnknown, LayerTooLarge, - InvalidRequest) + InvalidRequest, BlobDownloadGeoBlocked) from util.cache import cache_control from util.names import parse_namespace_repository @@ -65,6 +65,7 @@ def check_blob_exists(namespace_name, repo_name, digest): @process_registry_jwt_auth(scopes=['pull']) @require_repo_read @anon_protect +@check_region_blacklisted(BlobDownloadGeoBlocked) @cache_control(max_age=31536000) def download_blob(namespace_name, repo_name, digest): # Find the blob. diff --git a/endpoints/v2/errors.py b/endpoints/v2/errors.py index aad1bae4b..40ac46ef9 100644 --- a/endpoints/v2/errors.py +++ b/endpoints/v2/errors.py @@ -144,3 +144,10 @@ class NamespaceDisabled(V2RegistryException): def __init__(self, message=None): message = message or 'This namespace is disabled. Please contact your system administrator.' super(NamespaceDisabled, self).__init__('NAMESPACE_DISABLED', message, {}, 400) + + +class BlobDownloadGeoBlocked(V2RegistryException): + def __init__(self, detail=None): + message = ('The region from which you are pulling has been geo-ip blocked. ' + + 'Please contact the namespace owner.') + super(BlobDownloadGeoBlocked, self).__init__('BLOB_DOWNLOAD_GEO_BLOCKED', message, detail, 403) diff --git a/endpoints/verbs/__init__.py b/endpoints/verbs/__init__.py index cb1266b76..f69ac2506 100644 --- a/endpoints/verbs/__init__.py +++ b/endpoints/verbs/__init__.py @@ -13,7 +13,8 @@ from auth.permissions import ReadRepositoryPermission from data import database from data import model from data.registry_model import registry_model -from endpoints.decorators import anon_protect, anon_allowed, route_show_if, parse_repository_name +from endpoints.decorators import (anon_protect, anon_allowed, route_show_if, parse_repository_name, + check_region_blacklisted) from endpoints.v2.blob import BLOB_DIGEST_ROUTE from image.appc import AppCImageFormatter from image.docker import ManifestException @@ -273,6 +274,7 @@ def _repo_verb_signature(namespace, repository, tag_name, verb, checker=None, ** return make_response(signature_value) +@check_region_blacklisted() def _repo_verb(namespace, repository, tag_name, verb, formatter, sign=False, checker=None, **kwargs): # Verify that the image exists and that we have access to it. @@ -444,6 +446,7 @@ def get_squashed_tag(namespace, repository, tag): @verbs.route('/torrent{0}'.format(BLOB_DIGEST_ROUTE), methods=['GET']) @process_auth @parse_repository_name() +@check_region_blacklisted(namespace_name_kwarg='namespace_name') def get_tag_torrent(namespace_name, repo_name, digest): repo = model.repository.get_repository(namespace_name, repo_name) repo_is_public = repo is not None and model.repository.is_repository_public(repo) diff --git a/initdb.py b/initdb.py index ed17005db..3b9358980 100644 --- a/initdb.py +++ b/initdb.py @@ -920,7 +920,8 @@ def populate_database(minimal=False, with_storage=False): model.repositoryactioncount.update_repository_score(to_count) -WHITELISTED_EMPTY_MODELS = ['DeletedNamespace', 'LogEntry2', 'ManifestChild'] +WHITELISTED_EMPTY_MODELS = ['DeletedNamespace', 'LogEntry2', 'ManifestChild', + 'NamespaceGeoRestriction'] def find_models_missing_data(): # As a sanity check we are going to make sure that all db tables have some data, unless explicitly diff --git a/test/registry/fixtures.py b/test/registry/fixtures.py index 493c8c65a..15dd00887 100644 --- a/test/registry/fixtures.py +++ b/test/registry/fixtures.py @@ -15,7 +15,7 @@ from flask_principal import Identity from app import storage from data.database import (close_db_filter, configure, DerivedStorageForImage, QueueItem, Image, TagManifest, TagManifestToManifest, Manifest, ManifestLegacyImage, - ManifestBlob) + ManifestBlob, NamespaceGeoRestriction, User) from data import model from data.registry_model import registry_model from endpoints.csrf import generate_csrf_token @@ -116,6 +116,13 @@ def registry_server_executor(app): TagManifest.delete().execute() return 'OK' + def set_geo_block_for_namespace(namespace_name, iso_country_code): + NamespaceGeoRestriction.create(namespace=User.get(username=namespace_name), + description='', + unstructured_json={}, + restricted_region_iso_code=iso_country_code) + return 'OK' + executor = LiveServerExecutor() executor.register('generate_csrf', generate_csrf) executor.register('set_supports_direct_download', set_supports_direct_download) @@ -130,6 +137,7 @@ def registry_server_executor(app): executor.register('create_app_repository', create_app_repository) executor.register('disable_namespace', disable_namespace) executor.register('delete_manifests', delete_manifests) + executor.register('set_geo_block_for_namespace', set_geo_block_for_namespace) return executor diff --git a/test/registry/protocol_v1.py b/test/registry/protocol_v1.py index acd86547d..67849fc9d 100644 --- a/test/registry/protocol_v1.py +++ b/test/registry/protocol_v1.py @@ -15,6 +15,7 @@ class V1ProtocolSteps(Enum): PUT_TAG = 'put-tag' PUT_IMAGE_JSON = 'put-image-json' DELETE_TAG = 'delete-tag' + GET_LAYER = 'get-layer' class V1Protocol(RegistryProtocol): @@ -45,6 +46,9 @@ class V1Protocol(RegistryProtocol): Failures.INVALID_IMAGES: 400, Failures.NAMESPACE_DISABLED: 400, }, + V1ProtocolSteps.GET_LAYER: { + Failures.GEO_BLOCKED: 403, + }, } def __init__(self, jwk): @@ -118,8 +122,11 @@ class V1Protocol(RegistryProtocol): self.conduct(session, 'HEAD', image_prefix + 'layer', headers=headers) # And retrieve the layer data. - result = self.conduct(session, 'GET', image_prefix + 'layer', headers=headers) - assert result.content == images[index].bytes + result = self.conduct(session, 'GET', image_prefix + 'layer', headers=headers, + expected_status=(200, expected_failure, V1ProtocolSteps.GET_LAYER), + options=options) + if result.status_code == 200: + assert result.content == images[index].bytes return PullResult(manifests=None, image_ids=image_ids) diff --git a/test/registry/protocol_v2.py b/test/registry/protocol_v2.py index 4f9c53277..e449f8bdf 100644 --- a/test/registry/protocol_v2.py +++ b/test/registry/protocol_v2.py @@ -27,6 +27,7 @@ class V2ProtocolSteps(Enum): CATALOG = 'catalog' LIST_TAGS = 'list-tags' START_UPLOAD = 'start-upload' + GET_BLOB = 'get-blob' class V2Protocol(RegistryProtocol): @@ -48,6 +49,9 @@ class V2Protocol(RegistryProtocol): Failures.UNAUTHORIZED: 401, Failures.DISALLOWED_LIBRARY_NAMESPACE: 400, }, + V2ProtocolSteps.GET_BLOB: { + Failures.GEO_BLOCKED: 403, + }, V2ProtocolSteps.BLOB_HEAD_CHECK: { Failures.DISALLOWED_LIBRARY_NAMESPACE: 400, }, @@ -466,10 +470,11 @@ class V2Protocol(RegistryProtocol): assert response.headers['Content-Length'] == str(len(blob_bytes)) # And retrieve the blob data. - result = self.conduct(session, 'GET', - '/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name), blob_digest), - headers=headers, expected_status=200) - assert result.content == blob_bytes + if not options.skip_blob_push_checks: + result = self.conduct(session, 'GET', + '/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name), blob_digest), + headers=headers, expected_status=200) + assert result.content == blob_bytes return True @@ -558,8 +563,10 @@ class V2Protocol(RegistryProtocol): result = self.conduct(session, 'GET', '/v2/%s/blobs/%s' % (self.repo_name(namespace, repo_name), blob_digest), - expected_status=expected_status, - headers=headers) + expected_status=(expected_status, expected_failure, + V2ProtocolSteps.GET_BLOB), + headers=headers, + options=options) if expected_status == 200: assert result.content == image.bytes diff --git a/test/registry/protocols.py b/test/registry/protocols.py index fa68b47d9..f3a834e19 100644 --- a/test/registry/protocols.py +++ b/test/registry/protocols.py @@ -65,6 +65,7 @@ class Failures(Enum): INVALID_BLOB = 'invalid-blob' NAMESPACE_DISABLED = 'namespace-disabled' UNAUTHORIZED_FOR_MOUNT = 'unauthorized-for-mount' + GEO_BLOCKED = 'geo-blocked' class ProtocolOptions(object): @@ -78,6 +79,8 @@ class ProtocolOptions(object): self.accept_mimetypes = None self.mount_blobs = None self.push_by_manifest_digest = False + self.request_addr = None + self.skip_blob_push_checks = False @add_metaclass(ABCMeta) @@ -115,12 +118,16 @@ class RegistryProtocol(object): return repo_name def conduct(self, session, method, url, expected_status=200, params=None, data=None, - json_data=None, headers=None, auth=None): + json_data=None, headers=None, auth=None, options=None): if json_data is not None: data = json.dumps(json_data) headers = headers or {} headers['Content-Type'] = 'application/json' + if options and options.request_addr: + headers = headers or {} + headers['X-Override-Remote-Addr-For-Testing'] = options.request_addr + if isinstance(expected_status, tuple): expected_status, expected_failure, protocol_step = expected_status if expected_failure is not None: diff --git a/test/registry/registry_tests.py b/test/registry/registry_tests.py index 2f75ff845..332bfd1e2 100644 --- a/test/registry/registry_tests.py +++ b/test/registry/registry_tests.py @@ -1706,3 +1706,25 @@ def test_verify_schema2(v22_protocol, basic_images, liveserver_session, liveserv credentials=credentials) manifest = result.manifests['latest'] assert manifest.schema_version == 2 + + +def test_geo_blocking(pusher, puller, basic_images, liveserver_session, + liveserver, registry_server_executor, app_reloader): + """ Test: Attempt to pull an image from a geoblocked IP address. """ + credentials = ('devtable', 'password') + options = ProtocolOptions() + options.skip_blob_push_checks = True # Otherwise, cache gets established. + + # Push a new repository. + pusher.push(liveserver_session, 'devtable', 'newrepo', 'latest', basic_images, + credentials=credentials, options=options) + + registry_server_executor.on(liveserver).set_geo_block_for_namespace('devtable', 'US') + + # Attempt to pull the repository to verify. This should fail with a 403 due to + # the geoblocking of the IP being using. + options = ProtocolOptions() + options.request_addr = '6.0.0.0' + puller.pull(liveserver_session, 'devtable', 'newrepo', 'latest', basic_images, + credentials=credentials, options=options, + expected_failure=Failures.GEO_BLOCKED) diff --git a/util/ipresolver/__init__.py b/util/ipresolver/__init__.py index 9227d1d6f..fedcff5fa 100644 --- a/util/ipresolver/__init__.py +++ b/util/ipresolver/__init__.py @@ -16,7 +16,8 @@ import requests from util.abchelpers import nooper -ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token']) +ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token', + 'country_iso_code']) logger = logging.getLogger(__name__) @@ -124,7 +125,11 @@ class IPResolver(IPResolverInterface): fails, returns None. """ location_function = self._get_location_function() - if not ip_address or not location_function: + if not ip_address: + return None + + if not location_function: + logger.debug('No location function could be defined for IP address resolution') return None return location_function(ip_address) @@ -142,8 +147,9 @@ class IPResolver(IPResolverInterface): cache = CACHE sync_token = cache.get('sync_token', None) if sync_token is None: - logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json']) - return None + logger.debug('The aws ip range has not been cached from %s', + _DATA_FILES['aws-ip-ranges.json']) + return IPResolver._build_location_function(sync_token, set(), {}, self.geoip_db) all_amazon = cache['all_amazon'] regions = cache['regions'] @@ -163,20 +169,25 @@ class IPResolver(IPResolverInterface): try: parsed_ip = IPAddress(ip_address) except AddrFormatError: - return ResolvedLocation('invalid_ip', None, None, sync_token) + return ResolvedLocation('invalid_ip', None, None, sync_token, None) + + # Try geoip classification + try: + geoinfo = country_db.country(parsed_ip) + except geoip2.errors.AddressNotFoundError: + geoinfo = None if parsed_ip not in all_amazon: - # Try geoip classification - try: - found = country_db.country(parsed_ip) + if geoinfo: return ResolvedLocation( 'internet', - found.continent.code, - found.country.iso_code, + geoinfo.continent.code, + geoinfo.country.iso_code, sync_token, + geoinfo.country.iso_code, ) - except geoip2.errors.AddressNotFoundError: - return ResolvedLocation('internet', None, None, sync_token) + + return ResolvedLocation('internet', None, None, sync_token, None) region = None @@ -185,7 +196,8 @@ class IPResolver(IPResolverInterface): region = region_name break - return ResolvedLocation('aws', region, None, sync_token) + return ResolvedLocation('aws', region, None, sync_token, + geoinfo.country.country_iso_code if geoinfo else None) return _get_location @staticmethod diff --git a/util/ipresolver/test/test_ipresolver.py b/util/ipresolver/test/test_ipresolver.py index 19e8283a0..705df89a9 100644 --- a/util/ipresolver/test/test_ipresolver.py +++ b/util/ipresolver/test/test_ipresolver.py @@ -44,17 +44,17 @@ def test_unstarted(app, test_aws_ip, unstarted_cache): ipresolver = IPResolver(app) with patch.dict('util.ipresolver.CACHE', unstarted_cache): - assert ipresolver.resolve_ip(test_aws_ip) is None + assert ipresolver.resolve_ip(test_aws_ip) is not None def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app): with patch('util.ipresolver._UpdateIPRange'): ipresolver = IPResolver(app) with patch.dict('util.ipresolver.CACHE', test_ip_range_cache): - assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) - assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) - assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789) - assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789) + assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789, country_iso_code=None) + assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789, country_iso_code=None) + assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789, country_iso_code=u'US') + assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789, country_iso_code=None) def test_thread_missing_file(): class LoopInterruptionForTest(Exception):