import logging
import etcd
import uuid
import calendar
import os.path
import json

from datetime import datetime, timedelta
from trollius import From, coroutine, Return, async
from concurrent.futures import ThreadPoolExecutor
from urllib3.exceptions import ReadTimeoutError, ProtocolError

from buildman.manager.basemanager import BaseManager
from buildman.manager.executor import PopenExecutor, EC2Executor
from buildman.component.buildcomponent import BuildComponent
from buildman.jobutil.buildjob import BuildJob
from buildman.asyncutil import AsyncWrapper
from util.morecollections import AttrDict


logger = logging.getLogger(__name__)


# etcd key prefix under which realm specs are published so any manager can accept a builder.
ETCD_REALM_PREFIX = 'realm/'
# etcd key prefix under which per-build job records (expiration, builder id) are stored.
ETCD_BUILDER_PREFIX = 'building/'
# Passed as the ``timeout`` of etcd watches; per its name, intended to disable the client timeout.
ETCD_DISABLE_TIMEOUT = 0

class EtcdAction(object):
  """ String values of the ``action`` attribute found on etcd results handled by this module. """
  # Plain key operations
  CREATE = 'create'
  DELETE = 'delete'
  EXPIRE = 'expire'
  GET = 'get'
  SET = 'set'
  UPDATE = 'update'

  # Conditional (compare-and-*) operations
  COMPARE_AND_DELETE = 'compareAndDelete'
  COMPARE_AND_SWAP = 'compareAndSwap'


class EphemeralBuilderManager(BaseManager):
  """ Build manager implementation for the Enterprise Registry. """
  _executors = {
      'popen': PopenExecutor,
      'ec2': EC2Executor,
  }

  _etcd_client_klass = etcd.Client

  def __init__(self, *args, **kwargs):
    self._shutting_down = False

    self._manager_config = None
    self._async_thread_executor = None
    self._etcd_client = None

    self._component_to_job = {}
    self._job_uuid_to_component = {}
    self._component_to_builder = {}

    self._executor = None

    # Map of etcd keys being watched to the tasks watching them
    self._watch_tasks = {}

    super(EphemeralBuilderManager, self).__init__(*args, **kwargs)

  def _watch_etcd(self, etcd_key, change_callback, recursive=True):
    watch_task_key = (etcd_key, recursive)
    def callback_wrapper(changed_key_future):
      if watch_task_key not in self._watch_tasks or self._watch_tasks[watch_task_key].done():
        self._watch_etcd(etcd_key, change_callback)

      if changed_key_future.cancelled():
        # Due to lack of interest, tomorrow has been cancelled
        return

      try:
        etcd_result = changed_key_future.result()
      except (ReadTimeoutError, ProtocolError):
        return

      change_callback(etcd_result)

    if not self._shutting_down:
      watch_future = self._etcd_client.watch(etcd_key, recursive=recursive,
                                             timeout=ETCD_DISABLE_TIMEOUT)
      watch_future.add_done_callback(callback_wrapper)
      logger.debug('Scheduling watch of key: %s%s', etcd_key, '/*' if recursive else '')
      self._watch_tasks[watch_task_key] = async(watch_future)

  def _handle_builder_expiration(self, etcd_result):
    if etcd_result.action == EtcdAction.EXPIRE:
      # Handle the expiration
      logger.debug('Builder expired, clean up the old build node')
      job_metadata = json.loads(etcd_result._prev_node.value)

      if 'builder_id' in job_metadata:
        logger.info('Terminating expired build node.')
        async(self._executor.stop_builder(job_metadata['builder_id']))

  def _handle_realm_change(self, etcd_result):
    if etcd_result.action == EtcdAction.CREATE:
      # We must listen on the realm created by ourselves or another worker
      realm_spec = json.loads(etcd_result.value)
      self._register_realm(realm_spec)

    elif etcd_result.action == EtcdAction.DELETE or etcd_result.action == EtcdAction.EXPIRE:
      # We must stop listening for new connections on the specified realm, if we did not get the
      # connection
      realm_spec = json.loads(etcd_result._prev_node.value)
      build_job = BuildJob(AttrDict(realm_spec['job_queue_item']))
      component = self._job_uuid_to_component.pop(build_job.job_details['build_uuid'], None)
      if component is not None:
        # We were not the manager which the worker connected to, remove the bookkeeping for it
        logger.debug('Unregistering unused component on realm: %s', realm_spec['realm'])
        del self._component_to_job[component]
        del self._component_to_builder[component]
        self.unregister_component(component)

    else:
      logger.warning('Unexpected action (%s) on realm key: %s', etcd_result.action, etcd_result.key)

  def _register_realm(self, realm_spec):
    logger.debug('Registering realm with manager: %s', realm_spec['realm'])
    component = self.register_component(realm_spec['realm'], BuildComponent,
                                        token=realm_spec['token'])
    build_job = BuildJob(AttrDict(realm_spec['job_queue_item']))
    self._component_to_job[component] = build_job
    self._component_to_builder[component] = realm_spec['builder_id']
    self._job_uuid_to_component[build_job.job_details['build_uuid']] = component

  @coroutine
  def _register_existing_realms(self):
    try:
      all_realms = yield From(self._etcd_client.read(ETCD_REALM_PREFIX, recursive=True))
      for realm in all_realms.children:
        if not realm.dir:
          self._register_realm(json.loads(realm.value))
    except KeyError:
      # no realms have been registered yet
      pass

  def initialize(self, manager_config):
    logger.debug('Calling initialize')
    self._manager_config = manager_config

    executor_klass = self._executors.get(manager_config.get('EXECUTOR', ''), PopenExecutor)
    self._executor = executor_klass(manager_config.get('EXECUTOR_CONFIG', {}),
                                    self.manager_hostname)

    etcd_host = self._manager_config.get('ETCD_HOST', '127.0.0.1')
    etcd_port = self._manager_config.get('ETCD_PORT', 2379)
    etcd_auth = self._manager_config.get('ETCD_CERT_AND_KEY', None)
    etcd_ca_cert = self._manager_config.get('ETCD_CA_CERT', None)
    etcd_protocol = 'http' if etcd_auth is None else 'https'
    logger.debug('Connecting to etcd on %s:%s', etcd_host, etcd_port)

    worker_threads = self._manager_config.get('ETCD_WORKER_THREADS', 5)
    self._async_thread_executor = ThreadPoolExecutor(worker_threads)
    self._etcd_client = AsyncWrapper(self._etcd_client_klass(host=etcd_host, port=etcd_port,
                                                             cert=etcd_auth, ca_cert=etcd_ca_cert,
                                                             protocol=etcd_protocol),
                                     executor=self._async_thread_executor)

    self._watch_etcd(ETCD_BUILDER_PREFIX, self._handle_builder_expiration)
    self._watch_etcd(ETCD_REALM_PREFIX, self._handle_realm_change)

    # Load components for all realms currently known to the cluster
    async(self._register_existing_realms())

  def setup_time(self):
    setup_time = self._manager_config.get('MACHINE_SETUP_TIME', 300)
    return setup_time

  def shutdown(self):
    logger.debug('Shutting down worker.')
    self._shutting_down = True

    for (etcd_key, _), task in self._watch_tasks.items():
      if not task.done():
        logger.debug('Canceling watch task for %s', etcd_key)
        task.cancel()

    if self._async_thread_executor is not None:
      logger.debug('Shutting down thread pool executor.')
      self._async_thread_executor.shutdown()

  @coroutine
  def schedule(self, build_job):
    build_uuid = build_job.job_details['build_uuid']
    logger.debug('Calling schedule with job: %s', build_uuid)

    # Check if there are worker slots avialable by checking the number of jobs in etcd
    allowed_worker_count = self._manager_config.get('ALLOWED_WORKER_COUNT', 1)
    try:
      building = yield From(self._etcd_client.read(ETCD_BUILDER_PREFIX, recursive=True))
      workers_alive = sum(1 for child in building.children if not child.dir)
    except KeyError:
      workers_alive = 0

    logger.debug('Total jobs: %s', workers_alive)

    if workers_alive >= allowed_worker_count:
      logger.info('Too many workers alive, unable to start new worker. %s >= %s', workers_alive,
                  allowed_worker_count)
      raise Return(False)

    job_key = self._etcd_job_key(build_job)

    # First try to take a lock for this job, meaning we will be responsible for its lifeline
    realm = str(uuid.uuid4())
    token = str(uuid.uuid4())
    ttl = self.setup_time()
    expiration = datetime.utcnow() + timedelta(seconds=ttl)

    machine_max_expiration = self._manager_config.get('MACHINE_MAX_TIME', 7200)
    max_expiration = datetime.utcnow() + timedelta(seconds=machine_max_expiration)

    payload = {
        'expiration': calendar.timegm(expiration.timetuple()),
        'max_expiration': calendar.timegm(max_expiration.timetuple()),
    }

    try:
      yield From(self._etcd_client.write(job_key, json.dumps(payload), prevExist=False, ttl=ttl))
    except KeyError:
      # The job was already taken by someone else, we are probably a retry
      logger.error('Job already exists in etcd, are timeouts misconfigured or is the queue broken?')
      raise Return(False)

    logger.debug('Starting builder with executor: %s', self._executor)
    builder_id = yield From(self._executor.start_builder(realm, token, build_uuid))

    # Store the builder in etcd associated with the job id
    payload['builder_id'] = builder_id
    yield From(self._etcd_client.write(job_key, json.dumps(payload), prevExist=True, ttl=ttl))

    # Store the realm spec which will allow any manager to accept this builder when it connects
    realm_spec = json.dumps({
        'realm': realm,
        'token': token,
        'builder_id': builder_id,
        'job_queue_item': build_job.job_item,
    })
    try:
      yield From(self._etcd_client.write(self._etcd_realm_key(realm), realm_spec, prevExist=False,
                                         ttl=ttl))
    except KeyError:
      logger.error('Realm already exists in etcd. UUID collision or something is very very wrong.')
      raise Return(False)

    raise Return(True)

  @coroutine
  def build_component_ready(self, build_component):
    try:
      # Clean up the bookkeeping for allowing any manager to take the job
      job = self._component_to_job.pop(build_component)
      del self._job_uuid_to_component[job.job_details['build_uuid']]
      yield From(self._etcd_client.delete(self._etcd_realm_key(build_component.builder_realm)))

      logger.debug('Sending build %s to newly ready component on realm %s',
                   job.job_details['build_uuid'], build_component.builder_realm)
      yield From(build_component.start_build(job))
    except KeyError:
      logger.debug('Builder is asking for more work, but work already completed')

  def build_component_disposed(self, build_component, timed_out):
    logger.debug('Calling build_component_disposed.')

    # TODO make it so that I don't have to unregister the component if it timed out
    self.unregister_component(build_component)

  @coroutine
  def job_completed(self, build_job, job_status, build_component):
    logger.debug('Calling job_completed with status: %s', job_status)

    # Kill the ephmeral builder
    yield From(self._executor.stop_builder(self._component_to_builder.pop(build_component)))

    # Release the lock in etcd
    job_key = self._etcd_job_key(build_job)
    yield From(self._etcd_client.delete(job_key))

    self.job_complete_callback(build_job, job_status)

  @coroutine
  def job_heartbeat(self, build_job):
    # Extend the deadline in etcd
    job_key = self._etcd_job_key(build_job)
    build_job_metadata_response = yield From(self._etcd_client.read(job_key))
    build_job_metadata = json.loads(build_job_metadata_response.value)

    max_expiration = datetime.utcfromtimestamp(build_job_metadata['max_expiration'])
    max_expiration_remaining = max_expiration - datetime.utcnow()
    max_expiration_sec = max(0, int(max_expiration_remaining.total_seconds()))

    ttl = min(self.heartbeat_period_sec * 2, max_expiration_sec)
    new_expiration = datetime.utcnow() + timedelta(seconds=ttl)

    payload = {
        'expiration': calendar.timegm(new_expiration.timetuple()),
        'builder_id': build_job_metadata['builder_id'],
        'max_expiration': build_job_metadata['max_expiration'],
    }

    yield From(self._etcd_client.write(job_key, json.dumps(payload), ttl=ttl))

    self.job_heartbeat_callback(build_job)

  @staticmethod
  def _etcd_job_key(build_job):
    """ Create a key which is used to track a job in etcd.
    """
    return os.path.join(ETCD_BUILDER_PREFIX, build_job.job_details['build_uuid'])

  @staticmethod
  def _etcd_realm_key(realm):
    """ Create a key which is used to track an incoming connection on a realm.
    """
    return os.path.join(ETCD_REALM_PREFIX, realm)

  def num_workers(self):
    """ Return the number of workers we're managing locally.
    """
    return len(self._component_to_builder)