Some sort of command handling system

This commit is contained in:
Tulir Asokan 2018-12-18 00:53:39 +02:00
parent f104595217
commit 682eab348d
5 changed files with 107 additions and 27 deletions

View file

@ -1,3 +1,2 @@
from .plugin_base import Plugin
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent
from .handlers import event, command

View file

@ -13,3 +13,89 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any
import functools
import re
from mautrix.client import EventHandler
from mautrix.types import MessageType
from ..matrix import MaubotMessageEvent
from .event import EventHandlerDecorator
PrefixType = Union[str, Callable[[], str]]
CommandDecorator = Callable[[PrefixType, str], EventHandlerDecorator]
def _get_subcommand_decorator(parent: EventHandler) -> CommandDecorator:
def subcommand(name: PrefixType, help: str = None) -> EventHandlerDecorator:
cmd_decorator = new(name=f"{parent.__mb_name__} {name}", help=help)
def decorator(func: EventHandler) -> EventHandler:
func = cmd_decorator(func)
parent.__mb_subcommands__.append(func)
return func
return decorator
return subcommand
def new(name: Union[str, Callable[[], str]], help: str = None) -> EventHandlerDecorator:
def decorator(func: EventHandler) -> EventHandler:
func.__mb_subcommands__ = []
func.__mb_help__ = help
func.__mb_name__ = name or func.__name__
func.subcommand = _get_subcommand_decorator(func)
return func
return decorator
PassiveCommandHandler = Callable[[MaubotMessageEvent, ...], Awaitable[None]]
PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator",
Callable[[PassiveCommandHandler], PassiveCommandHandler])
def passive(regex: Union[str, Pattern], msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
field: Callable[[MaubotMessageEvent], str] = lambda event: event.content.body
) -> PassiveCommandHandlerDecorator:
if not isinstance(regex, Pattern):
regex = re.compile(regex)
def decorator(func: PassiveCommandHandler) -> PassiveCommandHandler:
@functools.wraps(func)
async def replacement(event: MaubotMessageEvent) -> None:
if event.sender == event.client.mxid:
return
elif msgtypes and event.content.msgtype not in msgtypes:
return
match = regex.match(field(event))
if match:
await func(event, *list(match.groups()))
return replacement
return decorator
class _Argument:
def __init__(self, name: str, required: bool, matches: Optional[str],
parser: Optional[Callable[[str], Any]]) -> None:
pass
def argument(name: str, *, required: bool = True, matches: Optional[str] = None,
parser: Optional[Callable[[str], Any]] = None) -> EventHandlerDecorator:
def decorator(func: EventHandler) -> EventHandler:
if not hasattr(func, "__mb_arguments__"):
func.__mb_arguments__ = []
func.__mb_arguments__.append(_Argument(name, required, matches, parser))
return func
return decorator
def vararg(func: EventHandler) -> EventHandler:
func.__mb_vararg__ = True
return func

View file

@ -23,20 +23,21 @@ from mautrix.client import EventHandler
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
def handler(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]:
def on(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]:
def decorator(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = True
if isinstance(var, EventType):
func.__mb_event_type__ = var
else:
func.__mb_event_type__ = EventType.ALL
return func
@functools.wraps(func)
async def wrapper(event: Event) -> None:
pass
wrapper.__mb_event_handler__ = True
if isinstance(var, EventType):
return decorator
wrapper.__mb_event_type__ = var
else:
decorator(var)
wrapper.__mb_event_type__ = EventType.ALL
return wrapper
return decorator if isinstance(var, EventType) else decorator(var)
class Field:

View file

@ -21,7 +21,7 @@ import attr
from mautrix import Client as MatrixClient
from mautrix.util.formatter import parse_html
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
MessageType, TextMessageEventContent, Format, RelatesTo)
MessageType, TextMessageEventContent, Format, RelatesTo, StateEvent)
class EscapeHTML(Extension):
@ -39,12 +39,12 @@ def parse_markdown(markdown: str, allow_html: bool = False) -> Tuple[str, str]:
class MaubotMessageEvent(MessageEvent):
_client: MatrixClient
client: MatrixClient
def __init__(self, base: MessageEvent, client: MatrixClient):
super().__init__(**{a.name.lstrip("_"): getattr(base, a.name)
for a in attr.fields(MessageEvent)})
self._client = client
self.client = client
def respond(self, content: Union[str, MessageEventContent],
event_type: EventType = EventType.ROOM_MESSAGE,
@ -56,7 +56,7 @@ class MaubotMessageEvent(MessageEvent):
content.body, content.formatted_body = parse_markdown(content.body)
if reply:
content.set_reply(self)
return self._client.send_message_event(self.room_id, event_type, content)
return self.client.send_message_event(self.room_id, event_type, content)
def reply(self, content: Union[str, MessageEventContent],
event_type: EventType = EventType.ROOM_MESSAGE,
@ -64,7 +64,7 @@ class MaubotMessageEvent(MessageEvent):
return self.respond(content, event_type, markdown, reply=True)
def mark_read(self) -> Awaitable[None]:
return self._client.send_receipt(self.room_id, self.event_id, "m.read")
return self.client.send_receipt(self.room_id, self.event_id, "m.read")
class MaubotMatrixClient(MatrixClient):
@ -79,10 +79,14 @@ class MaubotMatrixClient(MatrixClient):
async def call_handlers(self, event: Event) -> None:
if isinstance(event, MessageEvent):
event = MaubotMessageEvent(event, self)
else:
event.client = self
return await super().call_handlers(event)
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
event = await super().get_event(room_id, event_id)
if isinstance(event, MessageEvent):
return MaubotMessageEvent(event, self)
else:
event.client = self
return event

View file

@ -52,18 +52,8 @@ class Plugin(ABC):
for key in dir(self):
val = getattr(self, key)
if hasattr(val, "__mb_event_handler__"):
handle_own_events = hasattr(val, "__mb_handle_own_events__")
@functools.wraps(val)
async def handler(event: Event) -> None:
if not handle_own_events and getattr(event, "sender", "") == self.client.mxid:
return
for filter in val.__mb_event_filters__:
if not filter(event):
return
await val(event)
self._handlers_at_startup.append((handler, val.__mb_event_type__))
self.client.add_event_handler(val.__mb_event_type__, handler)
self._handlers_at_startup.append((val, val.__mb_event_type__))
self.client.add_event_handler(val.__mb_event_type__, val)
async def stop(self) -> None:
for func, event_type in self._handlers_at_startup: