import json

from calendar import timegm
from peewee import JOIN_LEFT_OUTER, fn
from datetime import datetime, timedelta
from cachetools import lru_cache

from data.database import LogEntry, LogEntryKind, User, RepositoryActionCount, db
from data.model import config, user, DataModelException

def _logs_query(selections, start_time, end_time, performer=None, repository=None, namespace=None,
                ignore=None):
  """ Builds the base LogEntry query shared by the log-listing functions.

      Selects `selections` from LogEntry for entries whose datetime lies in the half-open range
      [start_time, end_time). The query can be further narrowed to a repository, a performer,
      or a namespace (resolved by name through user.get_user_or_org). It can also exclude every
      log kind whose name appears in `ignore`.

      Raises DataModelException if `namespace` does not resolve to a user or organization.
      Raises KeyError if a name in `ignore` is not a known log entry kind.
  """
  query = LogEntry.select(*selections).switch(LogEntry)
  query = query.where(LogEntry.datetime >= start_time, LogEntry.datetime < end_time)

  if repository:
    query = query.where(LogEntry.repository == repository)

  if performer:
    query = query.where(LogEntry.performer == performer)

  if namespace:
    account = user.get_user_or_org(namespace)
    if account is None:
      raise DataModelException('Invalid namespace requested')
    query = query.where(LogEntry.account == account.id)

  if ignore:
    kinds = get_log_entry_kinds()
    excluded_kind_ids = [kinds[name] for name in ignore]
    query = query.where(~(LogEntry.kind << excluded_kind_ids))

  return query


@lru_cache(maxsize=1)
def get_log_entry_kinds():
  """ Returns a bidirectional map between log entry kind IDs and kind names.

      Each LogEntryKind row contributes two keys: its id (mapping to its name) and its name
      (mapping to its id). The result is memoized by lru_cache(maxsize=1).
  """
  mapping = {}
  for entry_kind in LogEntryKind.select():
    mapping.update({entry_kind.id: entry_kind.name, entry_kind.name: entry_kind.id})
  return mapping


def _get_log_entry_kind(name):
  """ Returns the ID of the log entry kind called `name`; raises KeyError if it is unknown. """
  return get_log_entry_kinds()[name]


def get_aggregated_logs(start_time, end_time, performer=None, repository=None, namespace=None,
                        ignore=None):
  """ Returns a query counting log entries per (day, kind) within [start_time, end_time).

      Each row selects `kind`, `day` (as produced by db.extract_date('day', ...)) and `count`.
      The filtering arguments are applied exactly as in _logs_query.
  """
  day = db.extract_date('day', LogEntry.datetime)
  query = _logs_query([LogEntry.kind, day.alias('day'), fn.Count(LogEntry.id).alias('count')],
                      start_time, end_time, performer, repository, namespace, ignore)
  return query.group_by(day, LogEntry.kind)


def get_logs_query(start_time, end_time, performer=None, repository=None, namespace=None,
                   ignore=None):
  """ Returns a query for full log entries within [start_time, end_time).

      The performing user is LEFT OUTER joined through an alias of User, with the join aliased as
      'performer'. Entries with no performer are therefore still included. The filtering
      arguments are applied exactly as in _logs_query.
  """
  PerformerUser = User.alias()
  base_query = _logs_query([LogEntry, PerformerUser], start_time, end_time, performer, repository,
                           namespace, ignore)
  join_condition = (LogEntry.performer == PerformerUser.id).alias('performer')
  return base_query.switch(LogEntry).join(PerformerUser, JOIN_LEFT_OUTER, on=join_condition)


def _json_serialize(obj):
  if isinstance(obj, datetime):
    return timegm(obj.utctimetuple())

  return obj


def log_action(kind_name, user_or_organization_name, performer=None, repository=None,
               ip=None, metadata=None, timestamp=None):
  """ Records a new LogEntry of kind `kind_name`.

      kind_name: name of an existing LogEntryKind; raises KeyError if unknown.
      user_or_organization_name: username of the account the entry is filed under; User.get
        raises if no such user exists. When None, the account is taken from the
        SERVICE_LOG_ACCOUNT_ID config value, or failing that the user with the lowest id.
      performer: optional model instance; only its `id` is stored.
      repository: optional model instance; only its `id` is stored.
      ip: stored as given.
      metadata: optional dict serialized to JSON (datetimes become Unix timestamps). None is
        treated as an empty dict.
      timestamp: time of the entry; when falsy, defaults to datetime.today() (naive local time).
  """
  if not timestamp:
    timestamp = datetime.today()

  # Avoid a shared mutable `{}` default; build a fresh dict per call instead.
  if metadata is None:
    metadata = {}

  if user_or_organization_name is not None:
    account = User.get(User.username == user_or_organization_name).id
  else:
    account = config.app_config.get('SERVICE_LOG_ACCOUNT_ID')
    if account is None:
      # No configured service account: fall back to the lowest-id user.
      account = User.select(fn.Min(User.id)).tuples().get()[0]

  performer_id = performer.id if performer is not None else None
  repository_id = repository.id if repository is not None else None

  kind = _get_log_entry_kind(kind_name)
  metadata_json = json.dumps(metadata, default=_json_serialize)
  LogEntry.create(kind=kind, account=account, performer=performer_id,
                  repository=repository_id, ip=ip, metadata_json=metadata_json,
                  datetime=timestamp)


def get_stale_logs_start_id():
  """ Returns the lowest LogEntry ID, or None when there are no log entries. """
  lowest_id_query = LogEntry.select(LogEntry.id).order_by(LogEntry.id).limit(1).tuples()
  try:
    first_row = lowest_id_query[0]
  except IndexError:
    return None
  return first_row[0]


def get_stale_logs_cutoff_id(cutoff_date):
  """ Returns the highest LogEntry ID whose datetime is at or before `cutoff_date`.

      Returns None when the query yields no row or when no entry qualifies.
  """
  max_id_query = (LogEntry
                  .select(fn.Max(LogEntry.id))
                  .where(LogEntry.datetime <= cutoff_date)
                  .tuples())
  try:
    row = max_id_query[0]
  except IndexError:
    return None
  return row[0]


def get_stale_logs(start_id, end_id):
  """ Returns a query over every LogEntry whose ID lies in [start_id, end_id], both ends included. """
  lower_bound = LogEntry.id >= start_id
  upper_bound = LogEntry.id <= end_id
  return LogEntry.select().where(lower_bound, upper_bound)


def delete_stale_logs(start_id, end_id):
  """ Deletes every LogEntry whose ID lies in [start_id, end_id], both ends included. """
  lower_bound = LogEntry.id >= start_id
  upper_bound = LogEntry.id <= end_id
  LogEntry.delete().where(lower_bound, upper_bound).execute()


def get_repository_action_counts(repo, start_date):
  """ Returns a query for the RepositoryActionCount rows of `repo` dated on or after `start_date`. """
  for_repo = RepositoryActionCount.repository == repo
  since_start = RepositoryActionCount.date >= start_date
  return RepositoryActionCount.select().where(for_repo, since_start)


def get_repositories_action_sums(repository_ids):
  """ Returns a dict mapping each repository to its summed action count over the last week.

      Keys are the repository column values returned by the tuples() query. Only
      RepositoryActionCount rows dated on or after (datetime.now() - 1 week) are summed, and
      repositories with no such rows are absent from the result. An empty `repository_ids`
      returns {} without querying.
  """
  if not repository_ids:
    return {}

  # Restrict the aggregation to recent rows only.
  one_week_ago = datetime.now() - timedelta(weeks=1)
  sums_query = (RepositoryActionCount
                .select(RepositoryActionCount.repository, fn.Sum(RepositoryActionCount.count))
                .where(RepositoryActionCount.repository << repository_ids)
                .where(RepositoryActionCount.date >= one_week_ago)
                .group_by(RepositoryActionCount.repository)
                .tuples())

  return {repository_id: total for repository_id, total in sums_query}