import logging
import json
import signal
import sys

from threading import Event, Lock
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime, timedelta
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from threading import Thread
from time import sleep

from data.model import db
from data.queue import WorkQueue

logger = logging.getLogger(__name__)

class JobException(Exception):
  """ A job exception is an exception that is caused by something being malformed in the job. When
      a worker raises this exception the job will be terminated and the retry will not be returned
      to the queue. """
  pass


class WorkerUnhealthyException(Exception):
  """ When this exception is raised, the worker is no longer healthy and will not accept any more
      work. When it is raised while processing a queue item, the item should be returned to the
      queue with its retry restored. """
  pass


class WorkerStatusServer(HTTPServer):
  def __init__(self, worker, *args, **kwargs):
    HTTPServer.__init__(self, *args, **kwargs)
    self.worker = worker


class WorkerStatusHandler(BaseHTTPRequestHandler):
  def do_GET(self):
    if self.path == '/status':
      # Return the worker status
      code = 200 if self.server.worker.is_healthy() else 503
      self.send_response(code)
      self.send_header('Content-Type', 'text/plain')
      self.end_headers()
      self.wfile.write('OK')
    elif self.path == '/terminate':
      # Return whether it is safe to terminate the worker process
      code = 200 if self.server.worker.is_terminated() else 503
      self.send_response(code)
      self.end_headers()
    else:
      self.send_error(404)

  def do_POST(self):
    if self.path == '/terminate':
      try:
        self.server.worker.join()
        self.send_response(200)
      except (Exception, SystemExit):
        # terminate() calls sys.exit(1) if the worker has already been terminated.
        logger.exception('Failed to terminate worker')
        self.send_response(500)
      self.end_headers()
    else:
      self.send_error(404)


class Worker(object):
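  """ Base class for workers which poll a WorkQueue for items and process them via
      process_queue_item. A periodic watchdog hook and an optional HTTP status server with
      health and termination endpoints are also provided. """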
  def __init__(self, queue, poll_period_seconds=30, reservation_seconds=300,
               watchdog_period_seconds=60, retry_after_seconds=300):
    self._sched = BackgroundScheduler()
    self._poll_period_seconds = poll_period_seconds
    self._reservation_seconds = reservation_seconds
    self._watchdog_period_seconds = watchdog_period_seconds
    self._retry_after_seconds = retry_after_seconds
    self._stop = Event()
    self._terminated = Event()
    self._queue = queue
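    # current_queue_item is only read or written while holding _current_item_lock.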
    self._current_item_lock = Lock()
    self.current_queue_item = None

  def process_queue_item(self, job_details):
    """ Return True if complete, False if it should be retried. """
    raise NotImplementedError('Workers must implement process_queue_item.')

  def watchdog(self):
    """ Function that gets run once every watchdog_period_seconds. """
    pass

  def _close_db_handle(self):
    if not db.is_closed():
      logger.debug('Disconnecting from database.')
      db.close()

  def is_healthy(self):
    return not self._stop.is_set()

  def is_terminated(self):
    return self._terminated.is_set()

  def extend_processing(self, seconds_from_now):
    with self._current_item_lock:
      if self.current_queue_item is not None:
        self._queue.extend_processing(self.current_queue_item, seconds_from_now)

  def run_watchdog(self):
    logger.debug('Running watchdog.')
    try:
      self.watchdog()
    except WorkerUnhealthyException as exc:
      logger.error('The worker has encountered an error via watchdog and will not take new jobs: %s',
                   exc)
      self.mark_current_incomplete(restore_retry=True)
      self._stop.set()

  def poll_queue(self):
    logger.debug('Getting work item from queue.')

    with self._current_item_lock:
      self.current_queue_item = self._queue.get(processing_time=self._reservation_seconds)

    while True:
      # Retrieve the current item in the queue over which to operate. We do so under
      # a lock to make sure we are always retrieving an item when in a healthy state.
      current_queue_item = None
      with self._current_item_lock:
        current_queue_item = self.current_queue_item
        if current_queue_item is None:
          # Close the db handle.
          self._close_db_handle()
          break

      logger.debug('Queue gave us some work: %s', current_queue_item.body)
      job_details = json.loads(current_queue_item.body)

      try:
        self.process_queue_item(job_details)
        self.mark_current_complete()

      except JobException as jex:
        logger.warning('An error occurred processing request: %s', current_queue_item.body)
        logger.warning('Job exception: %s', jex)
        self.mark_current_incomplete(restore_retry=False)

      except WorkerUnhealthyException as exc:
        logger.error('The worker has encountered an error via the job and will not take new jobs: %s',
                     exc)
        self.mark_current_incomplete(restore_retry=True)
        self._stop.set()

      finally:
        # Close the db handle.
        self._close_db_handle()

      if not self._stop.is_set():
        with self._current_item_lock:
          self.current_queue_item = self._queue.get(processing_time=self._reservation_seconds)

    if not self._stop.is_set():
      logger.debug('No more work.')

  def update_queue_metrics(self):
    self._queue.update_metrics()

  def start(self, start_status_server_port=None):
    if start_status_server_port is not None:
      # Start a status server on a thread
      server_address = ('', start_status_server_port)
      httpd = WorkerStatusServer(self, server_address, WorkerStatusHandler)
      server_thread = Thread(target=httpd.serve_forever)
      server_thread.daemon = True
      server_thread.start()

    logger.debug("Scheduling worker.")

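    # Start the polling and metrics jobs almost immediately rather than waiting a full interval.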
    soon = datetime.now() + timedelta(seconds=.001)

    self._sched.start()
    self._sched.add_job(self.poll_queue, 'interval', seconds=self._poll_period_seconds,
                        start_date=soon, max_instances=1)
    self._sched.add_job(self.update_queue_metrics, 'interval', seconds=60, start_date=soon,
                        max_instances=1)
    self._sched.add_job(self.run_watchdog, 'interval', seconds=self._watchdog_period_seconds,
                        max_instances=1)

    signal.signal(signal.SIGTERM, self.terminate)
    signal.signal(signal.SIGINT, self.terminate)

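    # Block the main thread until terminate() sets the stop event, waking up periodically so
    # signal handlers can run.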
    while not self._stop.wait(1):
      pass

    logger.debug('Waiting for running tasks to complete.')
    self._sched.shutdown()
    logger.debug('Finished.')

    self._terminated.set()

    # Wait forever if we're running a server
    while start_status_server_port is not None:
      sleep(60)

  def mark_current_incomplete(self, restore_retry=False):
    with self._current_item_lock:
      if self.current_queue_item is not None:
        self._queue.incomplete(self.current_queue_item, restore_retry=restore_retry,
                               retry_after=self._retry_after_seconds)
        self.current_queue_item = None

  def mark_current_complete(self):
    with self._current_item_lock:
      if self.current_queue_item is not None:
        self._queue.complete(self.current_queue_item)
        self.current_queue_item = None

  def terminate(self, signal_num=None, stack_frame=None, graceful=False):
    if self._terminated.is_set():
      sys.exit(1)

    else:
      logger.debug('Shutting down worker.')
      self._stop.set()

      if not graceful:
        # Give back the retry that we took for this queue item so that if it were down to zero
        # retries it will still be picked up by another worker
        self.mark_current_incomplete(restore_retry=True)

  def join(self):
    self.terminate(graceful=True)
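
# A minimal usage sketch (hypothetical): subclasses override process_queue_item and may override
# watchdog. The names ExampleWorker and example_queue, and the port below, are illustrative
# assumptions rather than part of this module.
#
#   class ExampleWorker(Worker):
#     def process_queue_item(self, job_details):
#       logger.debug('Processing job: %s', job_details)
#
#   worker = ExampleWorker(example_queue, poll_period_seconds=10)
#   worker.start(start_status_server_port=8080)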