# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    NewType,
    Optional,
    Pattern,
    Sequence,
    Set,
    Tuple,
    Union,
)
from abc import ABC, abstractmethod
import asyncio
import functools
import inspect
import re

from mautrix.types import EventType, MessageType

from ..matrix import MaubotMessageEvent
from . import event

# A command name: a literal string, a zero-argument callable, or a callable
# taking the bound instance and returning the name.
PrefixType = Optional[Union[str, Callable[[], str], Callable[[Any], str]]]
# Command aliases: a collection of names, or a predicate over the typed command
# (optionally taking the bound instance as first argument).
AliasesType = Union[
    List[str], Tuple[str, ...], Set[str], Callable[[str], bool], Callable[[Any, str], bool]
]
CommandHandlerFunc = NewType(
    "CommandHandlerFunc", Callable[[MaubotMessageEvent, Any], Awaitable[Any]]
)
CommandHandlerDecorator = NewType(
    "CommandHandlerDecorator",
    Callable[[Union["CommandHandler", CommandHandlerFunc]], "CommandHandler"],
)
PassiveCommandHandlerDecorator = NewType(
    "PassiveCommandHandlerDecorator", Callable[[CommandHandlerFunc], CommandHandlerFunc]
)


def _split_in_two(val: str, split_by: str) -> List[str]:
    """Split ``val`` at the first ``split_by`` into exactly two parts.

    When ``split_by`` does not occur in ``val`` the second part is ``""``.
    """
    head, _, tail = val.partition(split_by)
    return [head, tail]


class CommandHandler:
    """Callable wrapper that turns a coroutine function into a ``!command`` handler.

    The wrapper carries the command metadata (name getter, alias matcher,
    arguments, subcommands, help text) in ``__mb_*__`` attributes and, when
    called with a message event, parses the message body and invokes the
    wrapped function with the parsed arguments.

    It also implements ``__get__`` so that a handler defined in a class body is
    replaced by a per-instance copy that passes the instance to the wrapped
    function.
    """

    def __init__(self, func: CommandHandlerFunc) -> None:
        """Wrap ``func`` with default (not yet configured) command metadata."""
        self.__mb_func__: CommandHandlerFunc = func
        self.__mb_parent__: Optional[CommandHandler] = None
        self.__mb_subcommands__: List[CommandHandler] = []
        self.__mb_arguments__: List[Argument] = []
        self.__mb_help__: Optional[str] = None
        # Placeholders until `new()` installs the real name getter and matcher.
        self.__mb_get_name__: Callable[[Any], str] = lambda s: "noname"
        self.__mb_is_command_match__: Callable[[Any, str], bool] = self.__command_match_unset
        self.__mb_require_subcommand__: bool = True
        self.__mb_must_consume_args__: bool = True
        self.__mb_arg_fallthrough__: bool = True
        self.__mb_event_handler__: bool = True
        self.__mb_event_types__: set[EventType] = {EventType.ROOM_MESSAGE}
        self.__mb_msgtypes__: Iterable[MessageType] = (MessageType.TEXT,)
        # Cache of per-instance copies created by `__get__`, keyed by instance.
        self.__bound_copies__: Dict[Any, CommandHandler] = {}
        self.__bound_instance__: Any = None

    def __get__(self, instance, instancetype):
        """Return a copy of this handler bound to ``instance``.

        Access through the class (``instance is None``) or on an already bound
        handler returns ``self`` unchanged. Bound copies are created once per
        instance and cached; subcommands are bound recursively to the same
        instance.
        """
        if instance is None or self.__bound_instance__ is not None:
            return self
        try:
            return self.__bound_copies__[instance]
        except KeyError:
            new_ch = type(self)(self.__mb_func__)
            keys = [
                "parent",
                "subcommands",
                "arguments",
                "help",
                "get_name",
                "is_command_match",
                "require_subcommand",
                "must_consume_args",
                "arg_fallthrough",
                "event_handler",
                "event_types",
                "msgtypes",
            ]
            for key in keys:
                key = f"__mb_{key}__"
                setattr(new_ch, key, getattr(self, key))
            new_ch.__bound_instance__ = instance
            # Replace the shared subcommand list with copies bound to the same instance.
            new_ch.__mb_subcommands__ = [
                subcmd.__get__(instance, instancetype) for subcmd in self.__mb_subcommands__
            ]
            self.__bound_copies__[instance] = new_ch
            return new_ch

    @staticmethod
    def __command_match_unset(self, val: str) -> bool:
        """Default matcher used until `new()` configures the handler; always raises."""
        raise NotImplementedError("Hmm")

    async def __call__(
        self,
        evt: MaubotMessageEvent,
        *,
        _existing_args: Optional[Dict[str, Any]] = None,
        remaining_val: Optional[str] = None,
    ) -> Any:
        """Handle ``evt`` if it invokes this command.

        :param evt: The incoming message event.
        :param _existing_args: Arguments already parsed by a parent command.
        :param remaining_val: The unparsed rest of the message. ``None`` means
            this is a top-level call and the ``!command`` prefix still has to
            be matched against the message body.
        :return: The wrapped function's (or matching subcommand's) result, or
            ``None`` when the event is ignored or a usage/help reply was sent.
        """
        if evt.sender == evt.client.mxid or evt.content.msgtype not in self.__mb_msgtypes__:
            return
        if remaining_val is None:
            if not evt.content.body or evt.content.body[0] != "!":
                return
            command, remaining_val = _split_in_two(evt.content.body[1:], " ")
            command = command.lower()
            if not self.__mb_is_command_match__(self.__bound_instance__, command):
                return
        call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}

        # Without argument fallthrough, subcommands are tried before this
        # command's own arguments are parsed.
        if not self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
            ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
            if ok:
                return res

        ok, remaining_val = await self.__parse_args__(evt, call_args, remaining_val)
        if not ok:
            return
        elif self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
            # With fallthrough, the parsed arguments are handed to the subcommand.
            ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
            if ok:
                return res
            elif self.__mb_require_subcommand__:
                await evt.reply(self.__mb_full_help__)
                return

        if self.__mb_must_consume_args__ and remaining_val.strip():
            await evt.reply(self.__mb_full_help__)
            return

        if self.__bound_instance__ is not None:
            return await self.__mb_func__(self.__bound_instance__, evt, **call_args)
        return await self.__mb_func__(evt, **call_args)

    async def __call_subcommand__(
        self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str
    ) -> Tuple[bool, Any]:
        """Dispatch to the first subcommand matching the next word of ``remaining_val``.

        :return: ``(True, result)`` if a subcommand matched and was called,
            otherwise ``(False, None)``.
        """
        command, remaining_val = _split_in_two(remaining_val.strip(), " ")
        for subcommand in self.__mb_subcommands__:
            if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command):
                return True, await subcommand(
                    evt, _existing_args=call_args, remaining_val=remaining_val
                )
        return False, None

    async def __parse_args__(
        self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str
    ) -> Tuple[bool, str]:
        """Parse this command's arguments from ``remaining_val`` into ``call_args``.

        On a parse failure (an ``ArgumentSyntaxError``, any other ``ValueError``
        or a missing required argument) a usage message is sent as a reply.

        :return: ``(ok, remaining_val)`` where ``remaining_val`` is the text
            left over after the arguments parsed so far.
        """
        for arg in self.__mb_arguments__:
            try:
                remaining_val, call_args[arg.name] = arg.match(
                    remaining_val.strip(), evt=evt, instance=self.__bound_instance__
                )
                if arg.required and call_args[arg.name] is None:
                    raise ValueError("Argument required")
            except ArgumentSyntaxError as e:
                await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
                return False, remaining_val
            except ValueError:
                await evt.reply(self.__mb_usage__)
                return False, remaining_val
        return True, remaining_val

    @property
    def __mb_full_help__(self) -> str:
        """Usage line followed by one help line per subcommand."""
        usage = self.__mb_usage_without_subcommands__ + "\n\n"
        if not self.__mb_require_subcommand__:
            usage += f"* {self.__mb_prefix__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
        usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__)
        return usage

    @property
    def __mb_usage_args__(self) -> str:
        """Argument labels, ``<required>`` or ``[optional]``, joined by spaces.

        The subcommand placeholder is appended when the command has subcommands
        and argument fallthrough is enabled.
        """
        arg_usage = " ".join(
            f"<{arg.label}>" if arg.required else f"[{arg.label}]"
            for arg in self.__mb_arguments__
        )
        if self.__mb_subcommands__ and self.__mb_arg_fallthrough__:
            arg_usage += " " + self.__mb_usage_subcommand__
        return arg_usage

    @property
    def __mb_usage_subcommand__(self) -> str:
        """Placeholder shown in usage text where a subcommand and its arguments go."""
        return "<subcommand> [...]"

    @property
    def __mb_name__(self) -> str:
        """The command's name, resolved with the bound instance (if any)."""
        return self.__mb_get_name__(self.__bound_instance__)

    @property
    def __mb_prefix__(self) -> str:
        """The full invocation prefix: ``!`` followed by all ancestor names and this name.

        Ancestor names are resolved with this handler's bound instance, because
        ``__mb_parent__`` refers to the parent handler itself rather than to a
        bound copy of it.
        """
        names = [self.__mb_name__]
        parent = self.__mb_parent__
        while parent is not None:
            names.append(parent.__mb_get_name__(self.__bound_instance__))
            parent = parent.__mb_parent__
        return "!" + " ".join(reversed(names))

    @property
    def __mb_usage_inline__(self) -> str:
        """Bullet line(s) describing this command inside a parent's help text."""
        if not self.__mb_arg_fallthrough__:
            return (
                f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
                f"* {self.__mb_name__} {self.__mb_usage_subcommand__}"
            )
        return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}"

    @property
    def __mb_subcommands_list__(self) -> str:
        """Comma-separated list of subcommand names, with a bold heading."""
        return f"**Subcommands:** {', '.join(sc.__mb_name__ for sc in self.__mb_subcommands__)}"

    @property
    def __mb_usage_without_subcommands__(self) -> str:
        """The ``**Usage:**`` line for this command, without the subcommand list."""
        if not self.__mb_arg_fallthrough__:
            if not self.__mb_arguments__:
                return f"**Usage:** {self.__mb_prefix__} [subcommand] [...]"
            return (
                f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
                f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}"
            )
        return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"

    @property
    def __mb_usage__(self) -> str:
        """The usage line, followed by the subcommand list when there are subcommands."""
        if len(self.__mb_subcommands__) > 0:
            return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
        return self.__mb_usage_without_subcommands__

    def subcommand(
        self,
        name: PrefixType = None,
        *,
        help: Optional[str] = None,
        aliases: AliasesType = None,
        required_subcommand: bool = True,
        arg_fallthrough: bool = True,
    ) -> CommandHandlerDecorator:
        """Create a decorator that registers a function as a subcommand of this command.

        The decorated function is configured through :func:`new` with the given
        options, gets this handler as its parent, and is marked as not being an
        event handler of its own.

        :param name: Subcommand name; see :func:`new`.
        :param help: Help text for the subcommand.
        :param aliases: Alternative names or a matcher; see :func:`new`.
        :param required_subcommand: Passed to :func:`new` as ``require_subcommand``.
        :param arg_fallthrough: Passed to :func:`new` as ``arg_fallthrough``.
        """

        def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
            if not isinstance(func, CommandHandler):
                func = CommandHandler(func)
            new(
                name,
                help=help,
                aliases=aliases,
                require_subcommand=required_subcommand,
                arg_fallthrough=arg_fallthrough,
            )(func)
            func.__mb_parent__ = self
            func.__mb_event_handler__ = False
            self.__mb_subcommands__.append(func)
            return func

        return decorator


def new(
    name: PrefixType = None,
    *,
    help: Optional[str] = None,
    aliases: AliasesType = None,
    event_type: EventType = EventType.ROOM_MESSAGE,
    msgtypes: Optional[Iterable[MessageType]] = None,
    require_subcommand: bool = True,
    arg_fallthrough: bool = True,
    must_consume_args: bool = True,
) -> CommandHandlerDecorator:
    """Create a decorator that configures a function as a command handler.

    :param name: The command name. May be a string, a zero-argument callable,
        or a callable taking the bound instance. When falsy, the wrapped
        function's ``__name__`` with ``_`` replaced by ``-`` is used.
    :param help: Help text shown in usage messages.
    :param aliases: Either a list/set/tuple of extra names accepted in addition
        to ``name``, or a callable that decides whether a typed command matches
        (taking just the command, or the bound instance and the command).
    :param event_type: The event type the handler is registered for.
    :param msgtypes: Message types to react to; the handler's default is kept
        when this is empty or ``None``.
    :param require_subcommand: Reply with help when subcommands exist but none matched.
    :param arg_fallthrough: Parse this command's arguments before dispatching
        to subcommands and pass them on.
    :param must_consume_args: Reply with help when text is left over after parsing.
    """

    def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
        if not isinstance(func, CommandHandler):
            func = CommandHandler(func)
        func.__mb_help__ = help
        if name:
            if callable(name):
                # Normalise to the one-argument (instance) form used internally.
                if len(inspect.getfullargspec(name).args) == 0:
                    func.__mb_get_name__ = lambda self: name()
                else:
                    func.__mb_get_name__ = name
            else:
                func.__mb_get_name__ = lambda self: name
        else:
            func.__mb_get_name__ = lambda self: func.__mb_func__.__name__.replace("_", "-")
        if callable(aliases):
            # Normalise to the two-argument (instance, command) form used internally.
            if len(inspect.getfullargspec(aliases).args) == 1:
                func.__mb_is_command_match__ = lambda self, val: aliases(val)
            else:
                func.__mb_is_command_match__ = aliases
        elif isinstance(aliases, (list, set, tuple)):
            func.__mb_is_command_match__ = lambda self, val: (
                val == func.__mb_get_name__(self) or val in aliases
            )
        else:
            func.__mb_is_command_match__ = lambda self, val: val == func.__mb_get_name__(self)
        # Decorators are executed last to first, so we reverse the argument list.
        func.__mb_arguments__.reverse()
        func.__mb_require_subcommand__ = require_subcommand
        func.__mb_arg_fallthrough__ = arg_fallthrough
        func.__mb_must_consume_args__ = must_consume_args
        func.__mb_event_types__ = {event_type}
        if msgtypes:
            func.__mb_msgtypes__ = msgtypes
        return func

    return decorator


class ArgumentSyntaxError(ValueError):
    """Raised by argument matchers to report a malformed argument to the user.

    ``message`` is sent as the reply; ``show_usage`` controls whether the
    command usage is appended to it.
    """

    def __init__(self, message: str, show_usage: bool = True) -> None:
        super().__init__(message)
        self.message = message
        self.show_usage = show_usage


class Argument(ABC):
    """Base class for command arguments; instances double as decorators."""

    def __init__(
        self,
        name: str,
        label: Optional[str] = None,
        *,
        required: bool = False,
        pass_raw: bool = False,
    ) -> None:
        """
        :param name: Keyword under which the parsed value is passed to the handler.
        :param label: Name shown in usage text; defaults to ``name``.
        :param required: Whether a ``None`` match result is treated as an error.
        :param pass_raw: Match against the whole remaining text instead of only
            its first whitespace-delimited word.
        """
        self.name = name
        self.label = label or name
        self.required = required
        self.pass_raw = pass_raw

    @abstractmethod
    def match(self, val: str, **kwargs) -> Tuple[str, Any]:
        """Parse the argument from ``val``.

        :return: ``(remaining_text, parsed_value)``; a ``None`` value means the
            argument did not match.
        """
        pass

    def __call__(self, func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
        """Attach this argument to ``func``, wrapping it in a handler if needed."""
        if not isinstance(func, CommandHandler):
            func = CommandHandler(func)
        func.__mb_arguments__.append(self)
        return func


class RegexArgument(Argument):
    """Argument matched by a regular expression anchored at the start of the input."""

    def __init__(
        self,
        name: str,
        label: Optional[str] = None,
        *,
        required: bool = False,
        pass_raw: bool = False,
        matches: Optional[str] = None,
    ) -> None:
        super().__init__(name, label, required=required, pass_raw=pass_raw)
        # A single word must match in full; raw input only needs a matching prefix.
        matches = f"^{matches}" if self.pass_raw else f"^{matches}$"
        self.regex = re.compile(matches)

    def match(self, val: str, **kwargs) -> Tuple[str, Any]:
        """Match the pattern against ``val`` (or only its first word unless ``pass_raw``).

        :return: On a match, the input with the matched span removed and either
            the capture groups (if any) or the matched text; otherwise the
            unchanged input and ``None``.
        """
        orig_val = val
        if not self.pass_raw:
            val = re.split(r"\s", val, maxsplit=1)[0]
        match = self.regex.match(val)
        if match:
            return (
                orig_val[: match.start()] + orig_val[match.end() :],
                match.groups() or val[match.start() : match.end()],
            )
        return orig_val, None


class CustomArgument(Argument):
    """Argument parsed by a user-supplied ``matcher`` callable."""

    def __init__(
        self,
        name: str,
        label: Optional[str] = None,
        *,
        required: bool = False,
        pass_raw: bool = False,
        matcher: Callable[[str], Any],
    ) -> None:
        super().__init__(name, label, required=required, pass_raw=pass_raw)
        self.matcher = matcher

    def match(self, val: str, **kwargs) -> Tuple[str, Any]:
        """Run the matcher on ``val``.

        With ``pass_raw`` the matcher receives the whole input and its return
        value is returned as-is, so it must itself produce the
        ``(remaining_text, parsed_value)`` pair. Otherwise the matcher receives
        only the first word and a non-``None`` result consumes that word.
        """
        if self.pass_raw:
            return self.matcher(val)
        orig_val = val
        val = re.split(r"\s", val, maxsplit=1)[0]
        res = self.matcher(val)
        if res is not None:
            return orig_val[len(val) :], res
        return orig_val, None


class SimpleArgument(Argument):
    """Argument that takes the next word (or everything, with ``pass_raw``) verbatim."""

    def match(self, val: str, **kwargs) -> Tuple[str, Any]:
        """Return ``(rest, word)``, or ``("", val)`` when ``pass_raw`` is set."""
        if self.pass_raw:
            return "", val
        res = re.split(r"\s", val, maxsplit=1)[0]
        return val[len(res) :], res


def argument(
    name: str,
    label: Optional[str] = None,
    *,
    required: bool = True,
    matches: Optional[str] = None,
    parser: Optional[Callable[[str], Any]] = None,
    pass_raw: bool = False,
) -> CommandHandlerDecorator:
    """Create an argument decorator for a command handler.

    Returns a :class:`RegexArgument` when ``matches`` is given, else a
    :class:`CustomArgument` when ``parser`` is given, else a
    :class:`SimpleArgument`.

    :param name: Keyword under which the parsed value is passed to the handler.
    :param label: Name shown in usage text; defaults to ``name``.
    :param required: Whether the argument must be present.
    :param matches: Regular expression the argument must match.
    :param parser: Callable used to parse the argument.
    :param pass_raw: Give the matcher the whole remaining text instead of one word.
    """
    if matches:
        return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw)
    elif parser:
        return CustomArgument(name, label, required=required, matcher=parser, pass_raw=pass_raw)
    else:
        return SimpleArgument(name, label, required=required, pass_raw=pass_raw)


def passive(
    regex: Union[str, Pattern],
    *,
    msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
    field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
    event_type: EventType = EventType.ROOM_MESSAGE,
    multiple: bool = False,
    case_insensitive: bool = False,
    multiline: bool = False,
    dot_all: bool = False,
) -> PassiveCommandHandlerDecorator:
    """Create a decorator that runs a handler whenever a message matches ``regex``.

    The handler is called as ``func([self,] evt, val)`` where ``val`` is a
    tuple ``(text, *groups)`` for the first match, or a list of such tuples for
    all matches when ``multiple`` is set. It is not called when nothing matches.

    :param regex: Pattern string or compiled pattern. The flag options below
        only apply when a string is given.
    :param msgtypes: Message types to react to; an empty value disables the check.
    :param field: Extracts the text to search from the event.
    :param event_type: The event type to register the handler for.
    :param multiple: Collect all matches instead of only the first.
    :param case_insensitive: Compile with ``re.IGNORECASE``.
    :param multiline: Compile with ``re.MULTILINE``.
    :param dot_all: Compile with ``re.DOTALL``.
    """
    if not isinstance(regex, re.Pattern):
        flags = re.RegexFlag.UNICODE
        if case_insensitive:
            flags |= re.IGNORECASE
        if multiline:
            flags |= re.MULTILINE
        if dot_all:
            flags |= re.DOTALL
        regex = re.compile(regex, flags=flags)

    def decorator(func: CommandHandlerFunc) -> CommandHandlerFunc:
        # Stacking several @passive decorators: unwrap to the original function
        # and remember the previous wrapper so both can run.
        combine = None
        if hasattr(func, "__mb_passive_orig__"):
            combine = func
            func = func.__mb_passive_orig__

        @event.on(event_type)
        @functools.wraps(func)
        async def replacement(self, evt: MaubotMessageEvent = None) -> None:
            # Called as a plain function: the only positional argument is the event.
            if not evt and isinstance(self, MaubotMessageEvent):
                evt = self
                self = None
            if evt.sender == evt.client.mxid:
                return
            elif msgtypes and evt.content.msgtype not in msgtypes:
                return
            data = field(evt)
            # NOTE(review): match.pos/match.endpos are the bounds of the search,
            # not of the match, so the first tuple element is the whole searched
            # text rather than the matched substring. Left unchanged because
            # existing handlers may rely on it - confirm before switching to
            # match.group(0).
            if multiple:
                val = [
                    (data[match.pos : match.endpos], *match.groups())
                    for match in regex.finditer(data)
                ]
            else:
                match = regex.search(data)
                if match:
                    val = (data[match.pos : match.endpos], *match.groups())
                else:
                    val = None
            if val:
                if self:
                    await func(self, evt, val)
                else:
                    await func(evt, val)

        if combine:
            orig_replacement = replacement

            @event.on(event_type)
            @functools.wraps(func)
            async def replacement(self, evt: MaubotMessageEvent = None) -> None:
                await asyncio.gather(combine(self, evt), orig_replacement(self, evt))

        replacement.__mb_passive_orig__ = func
        return replacement

    return decorator