Add web handler decorators

This commit is contained in:
Tulir Asokan 2019-05-14 18:32:48 +03:00
parent 304c1b5536
commit 9bd06a3d64
7 changed files with 91 additions and 16 deletions

View file

@ -26,7 +26,7 @@ bcrypt_regex = re.compile(r"^\$2[ayb]\$.{56}$")
class Config(BaseFileConfig): class Config(BaseFileConfig):
@staticmethod @staticmethod
def _new_token() -> str: def _new_token() -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64)) return "".join(random.choices(string.ascii_lowercase + string.digits, k=64))
def do_update(self, helper: ConfigUpdateHelper) -> None: def do_update(self, helper: ConfigUpdateHelper) -> None:
base = helper.base base = helper.base

View file

@ -1 +1 @@
from . import event, command from . import event, command, web

View file

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
Dict, Tuple, Set) Dict, Tuple, Set, Iterable)
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
import functools import functools
@ -44,17 +44,17 @@ def _split_in_two(val: str, split_by: str) -> List[str]:
class CommandHandler: class CommandHandler:
def __init__(self, func: CommandHandlerFunc) -> None: def __init__(self, func: CommandHandlerFunc) -> None:
self.__mb_func__: CommandHandlerFunc = func self.__mb_func__: CommandHandlerFunc = func
self.__mb_parent__: CommandHandler = None self.__mb_parent__: Optional[CommandHandler] = None
self.__mb_subcommands__: List[CommandHandler] = [] self.__mb_subcommands__: List[CommandHandler] = []
self.__mb_arguments__: List[Argument] = [] self.__mb_arguments__: List[Argument] = []
self.__mb_help__: str = None self.__mb_help__: Optional[str] = None
self.__mb_get_name__: Callable[[], str] = None self.__mb_get_name__: Callable[[Any], str] = lambda s: "noname"
self.__mb_is_command_match__: Callable[[Any, str], bool] = self.__command_match_unset self.__mb_is_command_match__: Callable[[Any, str], bool] = self.__command_match_unset
self.__mb_require_subcommand__: bool = True self.__mb_require_subcommand__: bool = True
self.__mb_arg_fallthrough__: bool = True self.__mb_arg_fallthrough__: bool = True
self.__mb_event_handler__: bool = True self.__mb_event_handler__: bool = True
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
self.__mb_msgtypes__: List[MessageType] = (MessageType.TEXT,) self.__mb_msgtypes__: Iterable[MessageType] = (MessageType.TEXT,)
self.__bound_copies__: Dict[Any, CommandHandler] = {} self.__bound_copies__: Dict[Any, CommandHandler] = {}
self.__bound_instance__: Any = None self.__bound_instance__: Any = None
@ -78,7 +78,7 @@ class CommandHandler:
return new_ch return new_ch
@staticmethod @staticmethod
def __command_match_unset(self, val: str) -> str: def __command_match_unset(self, val: str) -> bool:
raise NotImplementedError("Hmm") raise NotImplementedError("Hmm")
async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None, async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None,
@ -132,7 +132,7 @@ class CommandHandler:
except ArgumentSyntaxError as e: except ArgumentSyntaxError as e:
await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else "")) await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
return False, remaining_val return False, remaining_val
except ValueError as e: except ValueError:
await evt.reply(self.__mb_usage__) await evt.reply(self.__mb_usage__)
return False, remaining_val return False, remaining_val
return True, remaining_val return True, remaining_val
@ -206,7 +206,7 @@ class CommandHandler:
def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None, def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
event_type: EventType = EventType.ROOM_MESSAGE, msgtypes: List[MessageType] = None, event_type: EventType = EventType.ROOM_MESSAGE, msgtypes: Iterable[MessageType] = None,
require_subcommand: bool = True, arg_fallthrough: bool = True) -> CommandHandlerDecorator: require_subcommand: bool = True, arg_fallthrough: bool = True) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler): if not isinstance(func, CommandHandler):

66
maubot/handlers/web.py Normal file
View file

@ -0,0 +1,66 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Any, Awaitable
from aiohttp import web, hdrs
WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]]
WebHandlerDecorator = Callable[[WebHandler], WebHandler]
def head(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_HEAD, path, **kwargs)
def options(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_OPTIONS, path, **kwargs)
def get(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_GET, path, **kwargs)
def post(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_POST, path, **kwargs)
def put(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_PUT, path, **kwargs)
def patch(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_PATCH, path, **kwargs)
def delete(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_DELETE, path, **kwargs)
def view(path: str, **kwargs: Any) -> WebHandlerDecorator:
return handle(hdrs.METH_ANY, path, **kwargs)
def handle(method: str, path: str, **kwargs) -> WebHandlerDecorator:
def decorator(handler: WebHandler) -> WebHandler:
try:
handlers = getattr(handler, "__mb_web_handler__")
except AttributeError:
handlers = []
setattr(handler, "__mb_web_handler__", handlers)
handlers.append((method, path, kwargs))
return handler
return decorator

View file

@ -215,7 +215,7 @@ class ZippedPluginLoader(PluginLoader):
async def unload(self) -> None: async def unload(self) -> None:
for name, mod in list(sys.modules.items()): for name, mod in list(sys.modules.items()):
if getattr(mod, "__file__", "").startswith(self.path): if (getattr(mod, "__file__", "") or "").startswith(self.path):
del sys.modules[name] del sys.modules[name]
self._loaded = None self._loaded = None
self.log.debug(f"Unloaded plugin {self.meta.id} at {self.path}") self.log.debug(f"Unloaded plugin {self.meta.id} at {self.path}")

View file

@ -55,14 +55,23 @@ class Plugin(ABC):
async def start(self) -> None: async def start(self) -> None:
for key in dir(self): for key in dir(self):
val = getattr(self, key) val = getattr(self, key)
if hasattr(val, "__mb_event_handler__") and val.__mb_event_handler__: try:
if val.__mb_event_handler__:
self._handlers_at_startup.append((val, val.__mb_event_type__)) self._handlers_at_startup.append((val, val.__mb_event_type__))
self.client.add_event_handler(val.__mb_event_type__, val) self.client.add_event_handler(val.__mb_event_type__, val)
except AttributeError:
pass
try:
web_handlers = val.__mb_web_handler__
for method, path, kwargs in web_handlers:
self.webapp.add_route(method=method, path=path, handler=val, **kwargs)
except AttributeError:
pass
async def stop(self) -> None: async def stop(self) -> None:
for func, event_type in self._handlers_at_startup: for func, event_type in self._handlers_at_startup:
self.client.remove_event_handler(event_type, func) self.client.remove_event_handler(event_type, func)
if self.webapp: if self.webapp is not None:
self.webapp.clear() self.webapp.clear()
@classmethod @classmethod

View file

@ -39,7 +39,7 @@ class PluginWebApp(web.UrlDispatcher):
self._named_resources = {} self._named_resources = {}
self._middleware = [] self._middleware = []
async def handle(self, request: web.Request) -> web.Response: async def handle(self, request: web.Request) -> web.StreamResponse:
match_info = await self.resolve(request) match_info = await self.resolve(request)
match_info.freeze() match_info.freeze()
resp = None resp = None