import itertools

from data import model, database
from endpoints.api.logs_models_interface import LogEntryDataInterface, LogEntryPage, LogEntry, AggregatedLogEntry


def _create_log(log):
  """
  Converts a log row into a `LogEntry`.

  `log` must expose `metadata_json`, `ip`, `datetime` and `kind_id`. Its `account` and
  `performer` references are read on a best-effort basis: every field that could not be read
  because of an `AttributeError` (e.g. the reference is None) is reported as None.
  """
  def read_fields(owner_attr, field_names):
    # Fields are read in the order given; the first AttributeError stops the scan, leaving
    # the remaining fields at None while keeping those already read.
    found = dict.fromkeys(field_names)
    try:
      for field_name in field_names:
        found[field_name] = getattr(getattr(log, owner_attr), field_name)
    except AttributeError:
      pass
    return found

  account = read_fields('account', ('organization', 'username', 'email', 'robot'))
  performer = read_fields('performer', ('robot', 'username', 'email'))

  return LogEntry(
    log.metadata_json,
    log.ip,
    log.datetime,
    performer['email'],
    performer['username'],
    performer['robot'],
    account['organization'],
    account['username'],
    account['email'],
    account['robot'],
    log.kind_id,
  )


class PreOCIModel(LogEntryDataInterface):
  """
  PreOCIModel implements the data model for the logs using a database schema
  before it was changed to support the OCI specification.
  """

  def _lookup_filters(self, namespace_name, repository_name, performer_name):
    """
    Resolves the optional repository and performer names into the objects passed to the
    `model.log` query helpers.

    Returns a `(repo, performer)` tuple. `repo` is only looked up when both the namespace and
    the repository name are given; `performer` only when a performer name is given. Otherwise
    the corresponding element is None.
    """
    # NOTE(review): if a lookup yields None for an unknown name, the matching filter is simply
    # left unset instead of producing an empty result -- confirm callers validate names first.
    repo = None
    if repository_name and namespace_name:
      repo = model.repository.get_repository(namespace_name, repository_name)

    performer = None
    if performer_name:
      performer = model.user.get_user(performer_name)

    return repo, performer

  def get_logs_query(self, start_time, end_time, performer_name=None, repository_name=None,
                     namespace_name=None, ignore=None, page_token=None):
    """
    Returns a `LogEntryPage` holding one page of logs between `start_time` and `end_time`,
    optionally filtered by performer, repository and namespace. `ignore` is forwarded
    unchanged to `model.log.get_logs_query`.

    Pages are requested in descending order with a limit of 20. `page_token` is the
    `next_page_token` of a previously returned page (or None for the first page).
    """
    repo, performer = self._lookup_filters(namespace_name, repository_name, performer_name)

    # TODO(LogMigrate): Remove the branch once we're back on LogEntry only.
    def get_logs(m):
      logs_query = model.log.get_logs_query(start_time, end_time, performer=performer,
                                            repository=repo, namespace=namespace_name,
                                            ignore=ignore, model=m)

      logs, next_page_token = model.modelutil.paginate(logs_query, m,
                                                       descending=True, page_token=page_token,
                                                       limit=20)
      return LogEntryPage([_create_log(log) for log in logs], next_page_token)

    # First check the LogEntry2 table for the most recent logs, unless we've been expressly told
    # to look inside the first table.
    TOKEN_TABLE_KEY = 'ttk'
    is_old_table = page_token is not None and page_token.get(TOKEN_TABLE_KEY) == 1
    if is_old_table:
      page_result = get_logs(database.LogEntry)
    else:
      page_result = get_logs(database.LogEntry2)

    if page_result.next_page_token is None and not is_old_table:
      # LogEntry2 is exhausted: hand back a token that sends the next request to LogEntry.
      page_result = page_result._replace(next_page_token={TOKEN_TABLE_KEY: 1})
    elif is_old_table and page_result.next_page_token is not None:
      # Still paging through LogEntry: keep the marker on the token so we stay on that table.
      page_result.next_page_token[TOKEN_TABLE_KEY] = 1

    return page_result

  def get_log_entry_kinds(self):
    """ Returns the result of `model.log.get_log_entry_kinds()`. """
    return model.log.get_log_entry_kinds()

  def repo_exists(self, namespace_name, repository_name):
    """ Returns whether a repository lookup for the given namespace and name finds a row. """
    return model.repository.get_repository(namespace_name, repository_name) is not None

  def get_aggregated_logs(self, start_time, end_time, performer_name=None, repository_name=None,
                          namespace_name=None, ignore=None):
    """
    Returns a list of `AggregatedLogEntry`, one per distinct (kind_id, day) pair found between
    `start_time` and `end_time`, with the counts from the LogEntry and LogEntry2 tables summed.
    The filters behave as in `get_logs_query`.
    """
    repo, performer = self._lookup_filters(namespace_name, repository_name, performer_name)

    # TODO(LogMigrate): Remove the branch once we're back on LogEntry only.
    aggregated_logs = model.log.get_aggregated_logs(start_time, end_time, performer=performer,
                                                    repository=repo, namespace=namespace_name,
                                                    ignore=ignore, model=database.LogEntry)
    aggregated_logs_2 = model.log.get_aggregated_logs(start_time, end_time, performer=performer,
                                                      repository=repo, namespace=namespace_name,
                                                      ignore=ignore, model=database.LogEntry2)

    # Merge rows from both tables, summing counts that share the same kind and day.
    entries = {}
    for log in itertools.chain(aggregated_logs, aggregated_logs_2):
      key = '%s-%s' % (log.kind_id, log.day)
      existing = entries.get(key)
      count = log.count if existing is None else log.count + existing.count
      entries[key] = AggregatedLogEntry(count, log.kind_id, log.day)

    # Materialize so callers get a concrete list rather than a live dict view on Python 3.
    return list(entries.values())


# Module-level PreOCIModel instance, created once when this module is imported.
pre_oci_model = PreOCIModel()