from data.database import LogEntryKind, LogEntry

class assert_action_logged(object):
  """ Context manager asserting that the wrapped code added a log entry of the given kind.

      The number of LogEntry rows of kind `log_kind` is sampled when the `with` block is entered
      and sampled again when it exits. The check passes only if the second sample is exactly one
      higher than the first. When the wrapped code raised, no check is made and the exception
      propagates unchanged.

      Usage:
        with assert_action_logged(kind_name):
          ...  # code expected to write exactly one log entry of that kind
  """
  def __init__(self, log_kind):
    # Name used to look up the LogEntryKind row each time the entries are counted.
    self.log_kind = log_kind
    # Baseline sample; replaced with the real count in __enter__.
    self.existing_count = 0

  def _get_log_count(self):
    """ Returns the current number of LogEntry rows whose kind is named `self.log_kind`. """
    entries = LogEntry.select()
    wanted_kind = LogEntryKind.get(name=self.log_kind)
    return entries.where(LogEntry.kind == wanted_kind).count()

  def __enter__(self):
    """ Records the baseline entry count and returns this instance. """
    baseline = self._get_log_count()
    self.existing_count = baseline
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    """ Asserts that exactly one entry was added, unless the wrapped code raised.

        Always returns None, so an exception raised inside the `with` block is never suppressed.
    """
    if exc_val is not None:
      # The body failed on its own; skip the count check and let that exception surface.
      return

    new_total = self._get_log_count()
    failure_text = 'Missing new log entry of kind %s' % self.log_kind
    assert new_total - 1 == self.existing_count, failure_text