import unittest

from endpoints.api import api
from app import app
from initdb import setup_database_for_testing, finished_database_for_testing
from specs import build_specs


# Mount the API blueprint on the application under the '/api' URL prefix.
app.register_blueprint(
  api,
  url_prefix='/api',
)


# Usernames the generated test classes log in as, one per access level.
# NOTE(review): these presumably must exist in the test database fixture.
NO_ACCESS_USER = 'freshuser'    # used by TestNoAccess
READ_ACCESS_USER = 'reader'     # used by TestReadAccess
ADMIN_ACCESS_USER = 'devtable'  # used by TestAdminAccess


class ApiTestCase(unittest.TestCase):
  """Base test case that brackets every test with test-database hooks."""

  def setUp(self):
    """Hand this test case to the test-database setup helper."""
    test_case = self
    setup_database_for_testing(test_case)

  def tearDown(self):
    """Hand this test case to the test-database cleanup helper."""
    test_case = self
    finished_database_for_testing(test_case)


class _SpecTestBuilder(type):
  @staticmethod
  def _test_generator(url, expected_status, open_kwargs, auth_username=None):
    def test(self):
      with app.test_client() as c:
        if auth_username:
          # Temporarily remove the teardown functions
          teardown_funcs = []
          if None in app.teardown_request_funcs:
            teardown_funcs = app.teardown_request_funcs[None]
            app.teardown_request_funcs[None] = []

          with c.session_transaction() as sess:
            sess['user_id'] = auth_username
            sess['identity.id'] = auth_username
            sess['identity.auth_type'] = 'username'

          # Restore the teardown functions
          app.teardown_request_funcs[None] = teardown_funcs

        rv = c.open(url, **open_kwargs)
        msg = '%s %s: %s expected: %s' % (open_kwargs['method'], url,
                                          rv.status_code, expected_status)
        self.assertEqual(rv.status_code, expected_status, msg)   
    return test


  def __new__(cls, name, bases, attrs):
    with app.test_request_context() as ctx:
      specs = attrs['spec_func']()
      for test_spec in specs:
        url, open_kwargs = test_spec.get_client_args()
        expected_status = getattr(test_spec, attrs['result_attr'])
        test = _SpecTestBuilder._test_generator(url, expected_status,
                                                open_kwargs,
                                                attrs['auth_username'])

        test_name_url = url.replace('/', '_').replace('-', '_')
        test_name = 'test_%s_%s' % (open_kwargs['method'].lower(),
                                    test_name_url)
        attrs[test_name] = test

    return type(name, bases, attrs)


class TestAnonymousAccess(ApiTestCase):
  """One test per build_specs() spec, unauthenticated, expecting anon_code."""
  __metaclass__ = _SpecTestBuilder
  auth_username = None
  result_attr = 'anon_code'
  spec_func = build_specs


class TestNoAccess(ApiTestCase):
  """One test per build_specs() spec as NO_ACCESS_USER, expecting no_access_code."""
  __metaclass__ = _SpecTestBuilder
  auth_username = NO_ACCESS_USER
  result_attr = 'no_access_code'
  spec_func = build_specs


class TestReadAccess(ApiTestCase):
  """One test per build_specs() spec as READ_ACCESS_USER, expecting read_code."""
  __metaclass__ = _SpecTestBuilder
  auth_username = READ_ACCESS_USER
  result_attr = 'read_code'
  spec_func = build_specs


class TestAdminAccess(ApiTestCase):
  """One test per build_specs() spec as ADMIN_ACCESS_USER, expecting admin_code."""
  __metaclass__ = _SpecTestBuilder
  auth_username = ADMIN_ACCESS_USER
  result_attr = 'admin_code'
  spec_func = build_specs


if __name__ == '__main__':
  # Executed directly: hand control to the unittest command-line runner.
  unittest.main()