import logging
import json
import requests

from abc import ABCMeta, abstractmethod
from six import add_metaclass

from cachetools import ttl_cache, lru_cache
from collections import namedtuple, defaultdict
from netaddr import IPNetwork, IPAddress, IPSet, AddrFormatError

import geoip2.database
import geoip2.errors

from util.abchelpers import nooper

# Result of IPResolver lookups. As produced by IPResolver._build_location_function:
#   provider   -- 'aws', 'internet', or 'invalid_ip'
#   region     -- AWS region name for 'aws'; GeoIP continent code for 'internet'; else None
#   service    -- GeoIP country ISO code for 'internet'; None otherwise
#   sync_token -- the `syncToken` of the AWS ranges file the lookup was made against
ResolvedLocation = namedtuple('ResolvedLocation', 'provider region service sync_token')

logger = logging.getLogger(__name__)

# Local data file name (under util/ipresolver/) -> URL it is refreshed from.
_DATA_FILES = {
  'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json',
}

def update_resolver_datafiles(timeout=60):
  """ Performs an update of the data file(s) used by the IP Resolver.

      Each entry of `_DATA_FILES` is downloaded and written to `util/ipresolver/<filename>`
      (a path relative to the current working directory). Each response is fetched and checked
      BEFORE its target file is opened. Opening with 'wb' truncates, so a failed download must
      never reach that point, or it would wipe out the existing data file.

      Args:
        timeout: seconds passed to `requests.get` as its connect/read timeout, so a stalled
                 server cannot hang the update forever.

      Raises:
        Exception: if a URL responds with a status code outside the 2XX range.
        requests.RequestException: if the request itself fails (connection error, timeout...).
  """
  # .items() instead of .iteritems() so this works on both Python 2 and 3.
  for filename, url in _DATA_FILES.items():
    logger.debug('Updating IP resolver data file "%s" from URL "%s"', filename, url)
    response = requests.get(url, timeout=timeout)
    logger.debug('Got %s response for URL %s', response.status_code, url)
    # Explicit range check: the previous `status_code / 2 != 100` rejected most 2XX codes.
    if not 200 <= response.status_code < 300:
      raise Exception('Got non-2XX status code for URL %s: %s' % (url, response.status_code))

    # Write the raw response bytes so the file is stored exactly as served, with no
    # decode/re-encode round trip.
    with open('util/ipresolver/%s' % filename, 'wb') as f:
      f.write(response.content)
      logger.debug('Successfully wrote %s', filename)


@add_metaclass(ABCMeta)
class IPResolverInterface(object):
  """ Abstract interface for looking up information about an IP address.

      Concrete resolvers must override `resolve_ip`.
  """

  @abstractmethod
  def resolve_ip(self, ip_address):
    """ Returns resolved information about `ip_address`, or None if it cannot be resolved. """


@nooper
class NoopIPResolver(IPResolverInterface):
  """ No-op implementation of IPResolverInterface.

      NOTE(review): `nooper` (util.abchelpers) presumably supplies do-nothing implementations of
      the interface's abstract methods (e.g. `resolve_ip`); confirm against that helper.
  """
  pass


class IPResolver(IPResolverInterface):
  """ Resolves IP addresses to a ResolvedLocation.

      Addresses covered by the AWS ranges in `util/ipresolver/aws-ip-ranges.json` resolve to
      provider 'aws' (with their AWS region when one matches). All other valid addresses are
      classified with the GeoLite2 country database as provider 'internet'. Unparseable
      addresses resolve to provider 'invalid_ip'.
  """
  def __init__(self, app):
    # `app` is stored but not used by any method of this class.
    self.app = app
    self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')

  def resolve_ip(self, ip_address):
    """ Attempts to return resolved information about the specified IP Address. If such an attempt fails,
        returns None.

        Returns a ResolvedLocation. Returns None if `ip_address` is empty, or if the AWS range data
        could not be loaded or parsed.
    """
    # Check the cheap precondition first so an empty address never triggers loading range data.
    if not ip_address:
      return None

    location_function = self._get_location_function()
    if location_function is None:
      return None

    return location_function(ip_address)

  def _get_aws_ip_ranges(self):
    """ Loads the AWS IP ranges JSON document from disk.

        Returns the parsed JSON, or None (after logging) if the file is missing or unreadable, or
        its contents are not valid JSON.
    """
    try:
      with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
        return json.load(f)
    except (IOError, ValueError, TypeError):
      logger.exception('Could not load AWS IP Ranges')
      return None

  @ttl_cache(maxsize=1, ttl=600)
  def _get_location_function(self):
    """ Builds the memoized `ip_address -> ResolvedLocation` function from the current range data.

        The result, including None on failure, is cached by cachetools' ttl_cache with ttl=600.
        Refreshed data files are therefore picked up, and failed loads retried, only after the
        cache entry expires.
    """
    aws_ip_range_json = self._get_aws_ip_ranges()
    if aws_ip_range_json is None:
      return None

    try:
      sync_token = aws_ip_range_json['syncToken']
      all_amazon, regions, services = IPResolver._parse_amazon_ranges(aws_ip_range_json)
    except (KeyError, TypeError, ValueError, AddrFormatError):
      # A structurally invalid ranges document must make resolution fail softly (None), as the
      # resolve_ip contract promises, instead of raising out of every lookup.
      logger.exception('Could not parse AWS IP Ranges')
      return None

    return IPResolver._build_location_function(sync_token, all_amazon, regions, services,
                                               self.geoip_db)

  @staticmethod
  def _build_location_function(sync_token, all_amazon, regions, services, country_db):
    """ Returns an LRU-memoized (4096 entries) function mapping an IP address to a ResolvedLocation.

        Args:
          sync_token: token copied into every returned ResolvedLocation.
          all_amazon: IPSet of every AWS prefix.
          regions: mapping of AWS region name -> IPSet.
          services: mapping of AWS service name -> IPSet. Accepted for completeness but not
                    consulted; AWS results always carry service=None.
          country_db: geoip2 Reader used to classify non-AWS addresses.
    """
    @lru_cache(maxsize=4096)
    def _get_location(ip_address):
      try:
        parsed_ip = IPAddress(ip_address)
      except AddrFormatError:
        return ResolvedLocation('invalid_ip', None, None, sync_token)

      if parsed_ip not in all_amazon:
        # Not an AWS address: fall back to GeoIP classification. The continent code is reported
        # in the `region` slot and the country ISO code in the `service` slot.
        try:
          found = country_db.country(parsed_ip)
        except geoip2.errors.AddressNotFoundError:
          return ResolvedLocation('internet', None, None, sync_token)

        return ResolvedLocation('internet', found.continent.code, found.country.iso_code,
                                sync_token)

      # First region whose ranges contain the address; None if it is only in the global set.
      region = next((region_name for region_name, region_set in regions.items()
                     if parsed_ip in region_set), None)
      return ResolvedLocation('aws', region, None, sync_token)

    return _get_location

  @staticmethod
  def _parse_amazon_ranges(ranges):
    """ Groups the CIDRs listed in an AWS ip-ranges document into IPSets.

        Reads the required `prefixes` list (entries keyed by `ip_prefix`). It also reads the
        optional `ipv6_prefixes` list (entries keyed by `ipv6_prefix`), so that IPv6 AWS
        addresses are recognised too.

        Returns:
          (all_amazon, regions, services): an IPSet of every prefix, plus defaultdicts mapping
          region name and service name to the IPSet of their prefixes.
    """
    all_amazon = IPSet()
    regions = defaultdict(IPSet)
    services = defaultdict(IPSet)

    prefix_lists = (
      (ranges['prefixes'], 'ip_prefix'),
      (ranges.get('ipv6_prefixes', []), 'ipv6_prefix'),
    )

    for prefix_list, prefix_key in prefix_lists:
      for service_description in prefix_list:
        cidr = IPNetwork(service_description[prefix_key])
        service = service_description['service']
        region = service_description['region']

        all_amazon.add(cidr)
        regions[region].add(cidr)
        services[service].add(cidr)

    return all_amazon, regions, services