diff --git a/maubot/handlers/command.py b/maubot/handlers/command.py
index 5e1b894..f850d12 100644
--- a/maubot/handlers/command.py
+++ b/maubot/handlers/command.py
@@ -13,7 +13,10 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, Dict
+from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
+ Dict, Tuple)
+from abc import ABC, abstractmethod
+import asyncio
import functools
import re
@@ -41,42 +44,65 @@ class CommandHandler:
self.__mb_name__: str = None
self.__mb_prefix__: str = None
self.__mb_require_subcommand__: bool = True
+ self.__mb_arg_fallthrough__: bool = True
self.__mb_event_handler__: bool = True
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
self.__class_instance: Any = None
- async def __call__(self, evt: MaubotMessageEvent, *,
- _existing_args: Dict[str, Any] = None) -> Any:
+ async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None,
+ _remaining_val: str = None) -> Any:
body = evt.content.body
- if evt.sender == evt.client.mxid or not body.startswith(self.__mb_prefix__):
+ has_prefix = _remaining_val or body.startswith(self.__mb_prefix__)
+ if evt.sender == evt.client.mxid or not has_prefix:
return
call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}
- remaining_val = body[len(self.__mb_prefix__) + 1:]
- # TODO update remaining_val somehow
+ remaining_val = _remaining_val or body[len(self.__mb_prefix__) + 1:]
+
+ if not self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
+ ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
+ if ok:
+ return res
+
+ ok, remaining_val = await self.__parse_args__(evt, call_args, remaining_val)
+ if not ok:
+ return
+ elif self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
+ ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
+ if ok:
+ return res
+ elif self.__mb_require_subcommand__:
+ await evt.reply(self.__mb_full_help__)
+ return
+
+ if self.__class_instance:
+ return await self.__mb_func__(self.__class_instance, evt, **call_args)
+ return await self.__mb_func__(evt, **call_args)
+
+ async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
+ remaining_val: str) -> Tuple[bool, Any]:
+ remaining_val = remaining_val.strip()
+ split = remaining_val.split(" ") if len(remaining_val) > 0 else []
+ try:
+ subcommand = self.__mb_subcommands__[split[0]]
+ return True, await subcommand(evt, _existing_args=call_args,
+ _remaining_val=" ".join(split[1:]))
+ except (KeyError, IndexError):
+ return False, None
+
+ async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
+ remaining_val: str) -> Tuple[bool, str]:
for arg in self.__mb_arguments__:
try:
- call_args[arg.name] = arg.match(remaining_val)
+ remaining_val, call_args[arg.name] = arg.match(remaining_val.strip())
if arg.required and not call_args[arg.name]:
raise ValueError("Argument required")
except ArgumentSyntaxError as e:
await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
- return
+ return False, remaining_val
except ValueError as e:
await evt.reply(self.__mb_usage__)
- return
-
- if len(self.__mb_subcommands__) > 0:
- split = remaining_val.split(" ") if len(remaining_val) > 0 else []
- try:
- subcommand = self.__mb_subcommands__[split[0]]
- return await subcommand(evt, _existing_args=call_args)
- except (KeyError, IndexError):
- if self.__mb_require_subcommand__:
- await evt.reply(self.__mb_full_help__)
- return
- return (await self.__mb_func__(self.__class_instance, evt, **call_args)
- if self.__class_instance
- else await self.__mb_func__(evt, **call_args))
+ return False, remaining_val
+ return True, remaining_val
def __get__(self, instance, instancetype):
self.__class_instance = instance
@@ -84,20 +110,45 @@ class CommandHandler:
@property
def __mb_full_help__(self) -> str:
- basic = self.__mb_usage__
- usage = f"{basic} [...]\n\n"
- usage += "\n".join(f"* {cmd.__mb_name__} {cmd.__mb_usage_args__} - {cmd.__mb_help__}"
- for cmd in self.__mb_subcommands__.values())
+ usage = self.__mb_usage_without_subcommands__ + "\n\n"
+ usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__.values())
return usage
@property
def __mb_usage_args__(self) -> str:
- return " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
- for arg in self.__mb_arguments__)
+ arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
+ for arg in self.__mb_arguments__)
+ if self.__mb_subcommands__ and self.__mb_arg_fallthrough__:
+ arg_usage += " " + self.__mb_usage_subcommand__
+ return arg_usage
+
+ @property
+ def __mb_usage_subcommand__(self) -> str:
+ return f" [...]"
+
+ @property
+ def __mb_usage_inline__(self) -> str:
+ if not self.__mb_arg_fallthrough__:
+ return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
+ f"* {self.__mb_name__} {self.__mb_usage_subcommand__}")
+ return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}"
+
+ @property
+ def __mb_subcommands_list__(self) -> str:
+ return f"**Subcommands:** {', '.join(self.__mb_subcommands__.keys())}"
+
+ @property
+ def __mb_usage_without_subcommands__(self) -> str:
+ if not self.__mb_arg_fallthrough__:
+ return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
+ f" _OR_ {self.__mb_usage_subcommand__}")
+ return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
@property
def __mb_usage__(self) -> str:
- return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
+ if len(self.__mb_subcommands__) > 0:
+ return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
+ return self.__mb_usage_without_subcommands__
def subcommand(self, name: PrefixType = None, help: str = None
) -> CommandHandlerDecorator:
@@ -114,6 +165,22 @@ class CommandHandler:
return decorator
+def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE,
+ require_subcommand: bool = True, arg_fallthrough: bool = True) -> CommandHandlerDecorator:
+ def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
+ if not isinstance(func, CommandHandler):
+ func = CommandHandler(func)
+ func.__mb_help__ = help
+ func.__mb_name__ = name or func.__name__
+ func.__mb_require_subcommand__ = require_subcommand
+ func.__mb_arg_fallthrough__ = arg_fallthrough
+ func.__mb_prefix__ = f"!{func.__mb_name__}"
+ func.__mb_event_type__ = event_type
+ return func
+
+ return decorator
+
+
class ArgumentSyntaxError(ValueError):
def __init__(self, message: str, show_usage: bool = True) -> None:
super().__init__(message)
@@ -121,36 +188,17 @@ class ArgumentSyntaxError(ValueError):
self.show_usage = show_usage
-class Argument:
+class Argument(ABC):
def __init__(self, name: str, label: str = None, *, required: bool = False,
- matches: Optional[str] = None, parser: Optional[Callable[[str], Any]] = None,
pass_raw: bool = False) -> None:
self.name = name
- self.required = required
self.label = label or name
+ self.required = required
+ self.pass_raw = pass_raw
- if not parser:
- if matches:
- regex = re.compile(matches)
-
- def parser(val: str) -> Optional[Sequence[str]]:
- match = regex.match(val)
- return match.groups() if match else None
- else:
- def parser(val: str) -> str:
- return val
-
- if not pass_raw:
- o_parser = parser
-
- def parser(val: str) -> Any:
- val = val.strip().split(" ")
- return o_parser(val[0])
-
- self.parser = parser
-
- def match(self, val: str) -> Any:
- return self.parser(val)
+ @abstractmethod
+ def match(self, val: str) -> Tuple[str, Any]:
+ pass
def __call__(self, func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
@@ -159,49 +207,108 @@ class Argument:
return func
-def new(name: PrefixType, *, help: str = None, event_type: EventType = EventType.ROOM_MESSAGE,
- require_subcommand: bool = True) -> CommandHandlerDecorator:
- def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
- if not isinstance(func, CommandHandler):
- func = CommandHandler(func)
- func.__mb_help__ = help
- func.__mb_name__ = name or func.__name__
- func.__mb_require_subcommand__ = require_subcommand
- func.__mb_prefix__ = f"!{func.__mb_name__}"
- func.__mb_event_type__ = event_type
- return func
+class RegexArgument(Argument):
+ def __init__(self, name: str, label: str = None, *, required: bool = False,
+ pass_raw: bool = False, matches: str = None) -> None:
+ super().__init__(name, label, required=required, pass_raw=pass_raw)
+ matches = f"^{matches}" if self.pass_raw else f"^{matches}$"
+ self.regex = re.compile(matches)
- return decorator
+ def match(self, val: str) -> Tuple[str, Any]:
+ orig_val = val
+ if not self.pass_raw:
+ val = val.split(" ")[0]
+ match = self.regex.match(val)
+ if match:
+ return (orig_val[:match.pos] + orig_val[match.endpos:],
+ match.groups() or val[match.pos:match.endpos])
+ return orig_val, None
+
+
+class CustomArgument(Argument):
+ def __init__(self, name: str, label: str = None, *, required: bool = False,
+ pass_raw: bool = False, matcher: Callable[[str], Any]) -> None:
+ super().__init__(name, label, required=required, pass_raw=pass_raw)
+ self.matcher = matcher
+
+ def match(self, val: str) -> Tuple[str, Any]:
+ if self.pass_raw:
+ return self.matcher(val)
+ orig_val = val
+ val = val.split(" ")[0]
+ res = self.matcher(val)
+ if res:
+ return orig_val[len(val):], res
+ return orig_val, None
+
+
+class SimpleArgument(Argument):
+ def match(self, val: str) -> Tuple[str, Any]:
+ if self.pass_raw:
+ return "", val
+ res = val.split(" ")[0]
+ return val[len(res):], res
def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None,
- parser: Optional[Callable[[str], Any]] = None) -> CommandHandlerDecorator:
- return Argument(name, label, required=required, matches=matches, parser=parser)
+ parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False
+ ) -> CommandHandlerDecorator:
+ if matches:
+ return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw)
+ elif parser:
+ return CustomArgument(name, label, required=required, matcher=parser, pass_raw=pass_raw)
+ else:
+ return SimpleArgument(name, label, required=required, pass_raw=pass_raw)
-def passive(regex: Union[str, Pattern], msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
- field: Callable[[MaubotMessageEvent], str] = lambda event: event.content.body,
- event_type: EventType = EventType.ROOM_MESSAGE) -> PassiveCommandHandlerDecorator:
+def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
+ field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
+ event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False
+ ) -> PassiveCommandHandlerDecorator:
if not isinstance(regex, Pattern):
regex = re.compile(regex)
def decorator(func: CommandHandlerFunc) -> CommandHandlerFunc:
+ combine = None
+ if hasattr(func, "__mb_passive_orig__"):
+ combine = func
+ func = func.__mb_passive_orig__
+
@event.on(event_type)
@functools.wraps(func)
- async def replacement(self, evt: MaubotMessageEvent) -> None:
- if isinstance(self, MaubotMessageEvent):
+ async def replacement(self, evt: MaubotMessageEvent = None) -> None:
+ if not evt and isinstance(self, MaubotMessageEvent):
evt = self
self = None
if evt.sender == evt.client.mxid:
return
elif msgtypes and evt.content.msgtype not in msgtypes:
return
- match = regex.match(field(evt))
- if match:
- if self:
- await func(self, evt, *list(match.groups()))
+ data = field(evt)
+ if multiple:
+ val = [(data[match.pos:match.endpos], *match.groups())
+ for match in regex.finditer(data)]
+ else:
+ match = regex.match(data)
+ if match:
+ val = (data[match.pos:match.endpos], *match.groups())
else:
- await func(evt, *list(match.groups()))
+ val = None
+ if val:
+ if self:
+ await func(self, evt, val)
+ else:
+ await func(evt, val)
+
+ if combine:
+ orig_replacement = replacement
+
+ @event.on(event_type)
+ @functools.wraps(func)
+ async def replacement(self, evt: MaubotMessageEvent = None) -> None:
+ await asyncio.gather(combine(self, evt), orig_replacement(self, evt))
+
+ replacement.__mb_passive_orig__ = func
return replacement