Merge pull request #38 from maubot/rewrite-command-handling

Rewrite command handling
This commit is contained in:
Tulir Asokan 2018-12-26 20:56:31 +02:00 committed by GitHub
commit 0a39c1365d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 449 additions and 252 deletions

View file

@ -1,6 +1,9 @@
# This is an example maubot plugin definition file.
# All plugins must include a file like this named "maubot.yaml" in their root directory.
# Target maubot version
maubot: 0.1.0
# The unique ID for the plugin. Java package naming style. (i.e. use your own domain, not xyz.maubot)
id: xyz.maubot.example
@ -24,6 +27,9 @@ modules:
# The main class must extend maubot.Plugin
main_class: HelloWorldBot
# Whether or not instances need a database
database: false
# Extra files that the upcoming build tool should include in the mbp file.
#extra_files:
#- base-config.yaml

View file

@ -1,3 +1,2 @@
from .plugin_base import Plugin
from .command_spec import CommandSpec, Command, PassiveCommand, Argument
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent

View file

@ -122,6 +122,8 @@ def upload_plugin(output: Union[str, IO], server: str) -> None:
@click.option("-s", "--server", help="Server to upload built plugin to")
def build(path: str, output: str, upload: bool, server: str) -> None:
meta = read_meta(path)
if not meta:
return
if output or not upload:
output = read_output_path(output, meta)
if not output:

View file

@ -55,7 +55,7 @@ class Client:
token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, store=self.db_instance)
if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
async def start(self, try_n: Optional[int] = 0) -> None:
try:
@ -260,9 +260,9 @@ class Client:
if value == self.db_instance.autojoin:
return
if value:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.db_instance.autojoin = value
@property

View file

@ -1,155 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Pattern, Union, Tuple, Optional, Any
from attr import dataclass
import re
from mautrix.types import MessageEvent, MatchedCommand, MatchedPassiveCommand
from mautrix.client.api.types.util import SerializableAttrs
@dataclass
class Argument(SerializableAttrs['Argument']):
matches: str
required: bool = False
description: str = None
@dataclass
class Command(SerializableAttrs['Command']):
syntax: str
arguments: Dict[str, Argument] = {}
description: str = None
@dataclass
class PassiveCommand(SerializableAttrs['PassiveCommand']):
name: str
matches: str
match_against: str
match_event: MessageEvent = None
class ParsedCommand:
name: str
is_passive: bool
arguments: List[str]
starts_with: str
matches: Pattern
match_against: str
match_event: MessageEvent
def __init__(self, command: Union[PassiveCommand, Command]) -> None:
if isinstance(command, PassiveCommand):
self._init_passive(command)
elif isinstance(command, Command):
self._init_active(command)
else:
raise ValueError("Command parameter must be a Command or a PassiveCommand.")
def _init_passive(self, command: PassiveCommand) -> None:
self.name = command.name
self.is_passive = True
self.match_against = command.match_against
self.matches = re.compile(command.matches, re.UNICODE)
self.match_event = command.match_event
def _init_active(self, command: Command) -> None:
self.name = command.syntax
self.is_passive = False
self.arguments = []
regex_builder = []
sw_builder = []
argument_encountered = False
for word in command.syntax.split(" "):
arg = command.arguments.get(word, None)
if arg is not None and len(word) > 0:
argument_encountered = True
regex = f"({arg.matches})"
if not arg.required:
regex += "?"
self.arguments.append(word)
regex_builder.append(regex)
else:
if not argument_encountered:
sw_builder.append(word)
regex_builder.append(re.escape(word))
self.starts_with = "!" + " ".join(sw_builder)
self.matches = re.compile("^!" + " ".join(regex_builder) + "$", re.UNICODE)
self.match_against = "body"
def match(self, evt: MessageEvent) -> bool:
return self._match_passive(evt) if self.is_passive else self._match_active(evt)
@staticmethod
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
if '.' not in key:
return key, None
key, next_key = key.split('.', 1)
if len(key) > 0 and key[0] == "[":
end_index = next_key.index("]")
key = key[1:] + "." + next_key[:end_index]
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
return key, next_key
@classmethod
def _recursive_get(cls, data: Any, key: str) -> Any:
if not data:
return None
key, next_key = cls._parse_key(key)
if next_key is not None:
return cls._recursive_get(data[key], next_key)
return data[key]
def _match_passive(self, evt: MessageEvent) -> bool:
try:
match_against = self._recursive_get(evt.content, self.match_against)
except KeyError:
match_against = None
match_against = match_against or evt.content.body
matches = [[match.string[match.start():match.end()]] + list(match.groups())
for match in self.matches.finditer(match_against)]
if not matches:
return False
if evt.unsigned.passive_command is None:
evt.unsigned.passive_command = {}
evt.unsigned.passive_command[self.name] = MatchedPassiveCommand(captured=matches)
return True
def _match_active(self, evt: MessageEvent) -> bool:
if not evt.content.body.startswith(self.starts_with):
return False
match = self.matches.match(evt.content.body)
if not match:
return False
evt.content.command = MatchedCommand(matched=self.name,
arguments=dict(zip(self.arguments, match.groups())))
return True
@dataclass
class CommandSpec(SerializableAttrs['CommandSpec']):
commands: List[Command] = []
passive_commands: List[PassiveCommand] = []
def __add__(self, other: 'CommandSpec') -> 'CommandSpec':
return CommandSpec(commands=self.commands + other.commands,
passive_commands=self.passive_commands + other.passive_commands)
def parse(self) -> List[ParsedCommand]:
return [ParsedCommand(command) for command in self.commands + self.passive_commands]

View file

@ -29,7 +29,8 @@ class Config(BaseFileConfig):
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
def do_update(self, helper: ConfigUpdateHelper) -> None:
base, copy, _ = helper
base = helper.base
copy = helper.copy
copy("database")
copy("plugin_directories.upload")
copy("plugin_directories.load")

View file

@ -0,0 +1 @@
from . import event, command

359
maubot/handlers/command.py Normal file
View file

@ -0,0 +1,359 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
Dict, Tuple, Set)
from abc import ABC, abstractmethod
import asyncio
import functools
import inspect
import re
from mautrix.types import MessageType, EventType
from ..matrix import MaubotMessageEvent
from . import event
PrefixType = Optional[Union[str, Callable[[], str]]]
AliasesType = Union[List[str], Tuple[str, ...], Set[str], Callable[[str], bool]]
CommandHandlerFunc = NewType("CommandHandlerFunc",
Callable[[MaubotMessageEvent, Any], Awaitable[Any]])
CommandHandlerDecorator = NewType("CommandHandlerDecorator",
Callable[[Union['CommandHandler', CommandHandlerFunc]],
'CommandHandler'])
PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator",
Callable[[CommandHandlerFunc], CommandHandlerFunc])
def _split_in_two(val: str, split_by: str) -> List[str]:
return val.split(split_by, 1) if split_by in val else [val, ""]
class CommandHandler:
def __init__(self, func: CommandHandlerFunc) -> None:
self.__mb_func__: CommandHandlerFunc = func
self.__mb_parent__: CommandHandler = None
self.__mb_subcommands__: List[CommandHandler] = []
self.__mb_arguments__: List[Argument] = []
self.__mb_help__: str = None
self.__mb_get_name__: Callable[[], str] = None
self.__mb_is_command_match__: Callable[[Any, str], bool] = self.__command_match_unset
self.__mb_require_subcommand__: bool = True
self.__mb_arg_fallthrough__: bool = True
self.__mb_event_handler__: bool = True
self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE
self.__class_instance: Any = None
@staticmethod
def __command_match_unset(self, val: str) -> str:
raise NotImplementedError("Hmm")
async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None,
remaining_val: str = None) -> Any:
if evt.sender == evt.client.mxid:
return
if remaining_val is None:
if not evt.content.body or evt.content.body[0] != "!":
return
command, remaining_val = _split_in_two(evt.content.body[1:], " ")
if not self.__mb_is_command_match__(self, command):
return
call_args: Dict[str, Any] = {**_existing_args} if _existing_args else {}
if not self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
if ok:
return res
ok, remaining_val = await self.__parse_args__(evt, call_args, remaining_val)
if not ok:
return
elif self.__mb_arg_fallthrough__ and len(self.__mb_subcommands__) > 0:
ok, res = await self.__call_subcommand__(evt, call_args, remaining_val)
if ok:
return res
elif self.__mb_require_subcommand__:
await evt.reply(self.__mb_full_help__)
return
if self.__class_instance:
return await self.__mb_func__(self.__class_instance, evt, **call_args)
return await self.__mb_func__(evt, **call_args)
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, Any]:
command, remaining_val = _split_in_two(remaining_val.strip(), " ")
for subcommand in self.__mb_subcommands__:
if subcommand.__mb_is_command_match__(subcommand.__class_instance, command):
return True, await subcommand(evt, _existing_args=call_args,
remaining_val=remaining_val)
return False, None
async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, str]:
for arg in self.__mb_arguments__:
try:
remaining_val, call_args[arg.name] = arg.match(remaining_val.strip())
if arg.required and not call_args[arg.name]:
raise ValueError("Argument required")
except ArgumentSyntaxError as e:
await evt.reply(e.message + (f"\n{self.__mb_usage__}" if e.show_usage else ""))
return False, remaining_val
except ValueError as e:
await evt.reply(self.__mb_usage__)
return False, remaining_val
return True, remaining_val
def __get__(self, instance, instancetype):
self.__class_instance = instance
return self
@property
def __mb_full_help__(self) -> str:
usage = self.__mb_usage_without_subcommands__ + "\n\n"
usage += "\n".join(cmd.__mb_usage_inline__ for cmd in self.__mb_subcommands__)
return usage
@property
def __mb_usage_args__(self) -> str:
arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
for arg in self.__mb_arguments__)
if self.__mb_subcommands__ and self.__mb_arg_fallthrough__:
arg_usage += " " + self.__mb_usage_subcommand__
return arg_usage
@property
def __mb_usage_subcommand__(self) -> str:
return f"<subcommand> [...]"
@property
def __mb_name__(self) -> str:
return self.__mb_get_name__(self.__class_instance)
@property
def __mb_prefix__(self) -> str:
if self.__mb_parent__:
return f"{self.__mb_parent__.__mb_prefix__} {self.__mb_name__}"
return f"!{self.__mb_name__}"
@property
def __mb_usage_inline__(self) -> str:
if not self.__mb_arg_fallthrough__:
return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}")
return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}"
@property
def __mb_subcommands_list__(self) -> str:
return f"**Subcommands:** {', '.join(self.__mb_subcommands__.keys())}"
@property
def __mb_usage_without_subcommands__(self) -> str:
if not self.__mb_arg_fallthrough__:
return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
f" _OR_ {self.__mb_usage_subcommand__}")
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
@property
def __mb_usage__(self) -> str:
if len(self.__mb_subcommands__) > 0:
return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
return self.__mb_usage_without_subcommands__
def subcommand(self, name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
required_subcommand: bool = True, arg_fallthrough: bool = True,
) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
new(name, help=help, aliases=aliases, require_subcommand=required_subcommand,
arg_fallthrough=arg_fallthrough)(func)
func.__mb_parent__ = self
func.__mb_event_handler__ = False
self.__mb_subcommands__.append(func)
return func
return decorator
def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
event_type: EventType = EventType.ROOM_MESSAGE, require_subcommand: bool = True,
arg_fallthrough: bool = True) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
func.__mb_help__ = help
if name:
if callable(name):
if len(inspect.getfullargspec(name).args) == 0:
func.__mb_get_name__ = lambda self: name()
else:
func.__mb_get_name__ = name
else:
func.__mb_get_name__ = lambda self: name
else:
func.__mb_get_name__ = lambda self: func.__name__
if callable(aliases):
if len(inspect.getfullargspec(aliases).args) == 1:
func.__mb_is_command_match__ = lambda self, val: aliases(val)
else:
func.__mb_is_command_match__ = aliases
elif isinstance(aliases, (list, set, tuple)):
func.__mb_is_command_match__ = lambda self, val: (val == func.__mb_name__
or val in aliases)
else:
func.__mb_is_command_match__ = lambda self, val: val == func.__mb_name__
# Decorators are executed last to first, so we reverse the argument list.
func.__mb_arguments__.reverse()
func.__mb_require_subcommand__ = require_subcommand
func.__mb_arg_fallthrough__ = arg_fallthrough
func.__mb_event_type__ = event_type
return func
return decorator
class ArgumentSyntaxError(ValueError):
def __init__(self, message: str, show_usage: bool = True) -> None:
super().__init__(message)
self.message = message
self.show_usage = show_usage
class Argument(ABC):
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False) -> None:
self.name = name
self.label = label or name
self.required = required
self.pass_raw = pass_raw
@abstractmethod
def match(self, val: str) -> Tuple[str, Any]:
pass
def __call__(self, func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
func.__mb_arguments__.append(self)
return func
class RegexArgument(Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False, matches: str = None) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw)
matches = f"^{matches}" if self.pass_raw else f"^{matches}$"
self.regex = re.compile(matches)
def match(self, val: str) -> Tuple[str, Any]:
orig_val = val
if not self.pass_raw:
val = val.split(" ")[0]
match = self.regex.match(val)
if match:
return (orig_val[:match.pos] + orig_val[match.endpos:],
match.groups() or val[match.pos:match.endpos])
return orig_val, None
class CustomArgument(Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False, matcher: Callable[[str], Any]) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw)
self.matcher = matcher
def match(self, val: str) -> Tuple[str, Any]:
if self.pass_raw:
return self.matcher(val)
orig_val = val
val = val.split(" ")[0]
res = self.matcher(val)
if res:
return orig_val[len(val):], res
return orig_val, None
class SimpleArgument(Argument):
def match(self, val: str) -> Tuple[str, Any]:
if self.pass_raw:
return "", val
res = val.split(" ")[0]
return val[len(res):], res
def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None,
parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False
) -> CommandHandlerDecorator:
if matches:
return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw)
elif parser:
return CustomArgument(name, label, required=required, matcher=parser, pass_raw=pass_raw)
else:
return SimpleArgument(name, label, required=required, pass_raw=pass_raw)
def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False
) -> PassiveCommandHandlerDecorator:
if not isinstance(regex, Pattern):
regex = re.compile(regex)
def decorator(func: CommandHandlerFunc) -> CommandHandlerFunc:
combine = None
if hasattr(func, "__mb_passive_orig__"):
combine = func
func = func.__mb_passive_orig__
@event.on(event_type)
@functools.wraps(func)
async def replacement(self, evt: MaubotMessageEvent = None) -> None:
if not evt and isinstance(self, MaubotMessageEvent):
evt = self
self = None
if evt.sender == evt.client.mxid:
return
elif msgtypes and evt.content.msgtype not in msgtypes:
return
data = field(evt)
if multiple:
val = [(data[match.pos:match.endpos], *match.groups())
for match in regex.finditer(data)]
else:
match = regex.match(data)
if match:
val = (data[match.pos:match.endpos], *match.groups())
else:
val = None
if val:
if self:
await func(self, evt, val)
else:
await func(evt, val)
if combine:
orig_replacement = replacement
@event.on(event_type)
@functools.wraps(func)
async def replacement(self, evt: MaubotMessageEvent = None) -> None:
await asyncio.gather(combine(self, evt), orig_replacement(self, evt))
replacement.__mb_passive_orig__ = func
return replacement
return decorator

34
maubot/handlers/event.py Normal file
View file

@ -0,0 +1,34 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Union, NewType
from mautrix.types import EventType
from mautrix.client import EventHandler
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
def on(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]:
def decorator(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = True
if isinstance(var, EventType):
func.__mb_event_type__ = var
else:
func.__mb_event_type__ = EventType.ALL
return func
return decorator if isinstance(var, EventType) else decorator(var)

View file

@ -15,12 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional
from asyncio import AbstractEventLoop
import os.path
import logging
import io
from sqlalchemy.orm import Session
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
from sqlalchemy.orm import Session
import sqlalchemy as sql
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.types import UserID
@ -133,8 +135,12 @@ class PluginInstance:
except (FileNotFoundError, KeyError):
self.base_cfg = None
self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config)
self.plugin = cls(self.client.client, self.loop, self.client.http_client, self.id,
self.log, self.config, self.mb_config["plugin_directories.db"])
db = None
if self.loader.meta.database:
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
db = sql.create_engine(f"sqlite:///{db_path}.db")
self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client,
instance_id=self.id, log=self.log, config=self.config, database=db)
try:
await self.plugin.start()
except Exception:

View file

@ -22,6 +22,7 @@ from packaging.version import Version, InvalidVersion
from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
deserializer)
from ..__meta__ import __version__
from ..plugin_base import Plugin
if TYPE_CHECKING:
@ -51,9 +52,12 @@ def deserialize_version(version: str) -> Version:
class PluginMeta(SerializableAttrs['PluginMeta']):
id: str
version: Version
license: str
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []

View file

@ -13,18 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Union, Callable, Awaitable, Optional, Tuple
from typing import Union, Awaitable, Optional, Tuple
from markdown.extensions import Extension
import markdown as md
import attr
from mautrix import Client as MatrixClient
from mautrix.util.formatter import parse_html
from mautrix.client import EventHandler
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
MessageType, TextMessageEventContent, Format, RelatesTo)
from .command_spec import ParsedCommand, CommandSpec
MessageType, TextMessageEventContent, Format, RelatesTo, StateEvent)
class EscapeHTML(Extension):
@ -42,12 +39,12 @@ def parse_markdown(markdown: str, allow_html: bool = False) -> Tuple[str, str]:
class MaubotMessageEvent(MessageEvent):
_client: MatrixClient
client: MatrixClient
def __init__(self, base: MessageEvent, client: MatrixClient):
super().__init__(**{a.name.lstrip("_"): getattr(base, a.name)
for a in attr.fields(MessageEvent)})
self._client = client
self.client = client
def respond(self, content: Union[str, MessageEventContent],
event_type: EventType = EventType.ROOM_MESSAGE,
@ -59,7 +56,7 @@ class MaubotMessageEvent(MessageEvent):
content.body, content.formatted_body = parse_markdown(content.body)
if reply:
content.set_reply(self)
return self._client.send_message_event(self.room_id, event_type, content)
return self.client.send_message_event(self.room_id, event_type, content)
def reply(self, content: Union[str, MessageEventContent],
event_type: EventType = EventType.ROOM_MESSAGE,
@ -67,18 +64,10 @@ class MaubotMessageEvent(MessageEvent):
return self.respond(content, event_type, markdown, reply=True)
def mark_read(self) -> Awaitable[None]:
return self._client.send_receipt(self.room_id, self.event_id, "m.read")
return self.client.send_receipt(self.room_id, self.event_id, "m.read")
class MaubotMatrixClient(MatrixClient):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.command_specs: Dict[str, CommandSpec] = {}
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
def send_markdown(self, room_id: RoomID, markdown: str, msgtype: MessageType = MessageType.TEXT,
relates_to: Optional[RelatesTo] = None, **kwargs) -> Awaitable[EventID]:
content = TextMessageEventContent(msgtype=msgtype, format=Format.HTML)
@ -87,60 +76,17 @@ class MaubotMatrixClient(MatrixClient):
content.relates_to = relates_to
return self.send_message(room_id, content, **kwargs)
def set_command_spec(self, plugin_id: str, spec: CommandSpec) -> None:
self.command_specs[plugin_id] = spec
self._reparse_command_specs()
def _reparse_command_specs(self) -> None:
self.commands = [parsed_command
for spec in self.command_specs.values()
for parsed_command in spec.parse()]
def remove_command_spec(self, plugin_id: str) -> None:
try:
del self.command_specs[plugin_id]
self._reparse_command_specs()
except KeyError:
pass
async def _command_event_handler(self, evt: MessageEvent) -> None:
if evt.sender == self.mxid or evt.content.msgtype == MessageType.NOTICE:
return
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass
async def call_handlers(self, event: Event) -> None:
if isinstance(event, MessageEvent):
event = MaubotMessageEvent(event, self)
else:
event.client = self
return await super().call_handlers(event)
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
event = await super().get_event(room_id, event_id)
if isinstance(event, MessageEvent):
return MaubotMessageEvent(event, self)
else:
event.client = self
return event

View file

@ -14,21 +14,18 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Optional, TYPE_CHECKING
from abc import ABC
from logging import Logger
from abc import ABC, abstractmethod
from asyncio import AbstractEventLoop
from aiohttp import ClientSession
import os.path
import functools
from sqlalchemy.engine.base import Engine
import sqlalchemy as sql
from aiohttp import ClientSession
if TYPE_CHECKING:
from .client import MaubotMatrixClient
from .command_spec import CommandSpec
from mautrix.types import Event
from mautrix.util.config import BaseProxyConfig
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
from .client import MaubotMatrixClient
class Plugin(ABC):
@ -37,33 +34,30 @@ class Plugin(ABC):
log: Logger
loop: AbstractEventLoop
config: Optional['BaseProxyConfig']
database: Optional[Engine]
def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession,
plugin_instance_id: str, log: Logger, config: Optional['BaseProxyConfig'],
db_base_path: str) -> None:
instance_id: str, log: Logger, config: Optional['BaseProxyConfig'],
database: Optional[Engine]) -> None:
self.client = client
self.loop = loop
self.http = http
self.id = plugin_instance_id
self.id = instance_id
self.log = log
self.config = config
self.__db_base_path = db_base_path
self.database = database
self._handlers_at_startup = []
def request_db_engine(self) -> Optional[Engine]:
if not self.__db_base_path:
raise DatabaseNotConfigured
return sql.create_engine(f"sqlite:///{os.path.join(self.__db_base_path, self.id)}.db")
def set_command_spec(self, spec: 'CommandSpec') -> None:
self.client.set_command_spec(self.id, spec)
@abstractmethod
async def start(self) -> None:
pass
for key in dir(self):
val = getattr(self, key)
if hasattr(val, "__mb_event_handler__") and val.__mb_event_handler__:
self._handlers_at_startup.append((val, val.__mb_event_type__))
self.client.add_event_handler(val.__mb_event_type__, val)
@abstractmethod
async def stop(self) -> None:
pass
for func, event_type in self._handlers_at_startup:
self.client.remove_event_handler(event_type, func)
@classmethod
def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]: