Changes related to rewriting the command handling system

This commit is contained in:
Tulir Asokan 2018-12-16 00:52:54 +02:00
parent 69de2c9d85
commit f104595217
10 changed files with 182 additions and 243 deletions

View file

@ -1,6 +1,9 @@
# This is an example maubot plugin definition file.
# All plugins must include a file like this named "maubot.yaml" in their root directory.
# Target maubot version
maubot: 0.1.0
# The unique ID for the plugin. Java package naming style. (i.e. use your own domain, not xyz.maubot)
id: xyz.maubot.example
@ -24,6 +27,9 @@ modules:
# The main class must extend maubot.Plugin
main_class: HelloWorldBot
# Whether or not instances need a database
database: false
# Extra files that the upcoming build tool should include in the mbp file.
#extra_files:
#- base-config.yaml

View file

@ -1,3 +1,3 @@
from .plugin_base import Plugin
from .command_spec import CommandSpec, Command, PassiveCommand, Argument
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent
from .handlers import event, command

View file

@ -1,155 +0,0 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Dict, Pattern, Union, Tuple, Optional, Any
from attr import dataclass
import re
from mautrix.types import MessageEvent, MatchedCommand, MatchedPassiveCommand
from mautrix.client.api.types.util import SerializableAttrs
@dataclass
class Argument(SerializableAttrs['Argument']):
matches: str
required: bool = False
description: str = None
@dataclass
class Command(SerializableAttrs['Command']):
syntax: str
arguments: Dict[str, Argument] = {}
description: str = None
@dataclass
class PassiveCommand(SerializableAttrs['PassiveCommand']):
name: str
matches: str
match_against: str
match_event: MessageEvent = None
class ParsedCommand:
name: str
is_passive: bool
arguments: List[str]
starts_with: str
matches: Pattern
match_against: str
match_event: MessageEvent
def __init__(self, command: Union[PassiveCommand, Command]) -> None:
if isinstance(command, PassiveCommand):
self._init_passive(command)
elif isinstance(command, Command):
self._init_active(command)
else:
raise ValueError("Command parameter must be a Command or a PassiveCommand.")
def _init_passive(self, command: PassiveCommand) -> None:
self.name = command.name
self.is_passive = True
self.match_against = command.match_against
self.matches = re.compile(command.matches, re.UNICODE)
self.match_event = command.match_event
def _init_active(self, command: Command) -> None:
self.name = command.syntax
self.is_passive = False
self.arguments = []
regex_builder = []
sw_builder = []
argument_encountered = False
for word in command.syntax.split(" "):
arg = command.arguments.get(word, None)
if arg is not None and len(word) > 0:
argument_encountered = True
regex = f"({arg.matches})"
if not arg.required:
regex += "?"
self.arguments.append(word)
regex_builder.append(regex)
else:
if not argument_encountered:
sw_builder.append(word)
regex_builder.append(re.escape(word))
self.starts_with = "!" + " ".join(sw_builder)
self.matches = re.compile("^!" + " ".join(regex_builder) + "$", re.UNICODE)
self.match_against = "body"
def match(self, evt: MessageEvent) -> bool:
return self._match_passive(evt) if self.is_passive else self._match_active(evt)
@staticmethod
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
if '.' not in key:
return key, None
key, next_key = key.split('.', 1)
if len(key) > 0 and key[0] == "[":
end_index = next_key.index("]")
key = key[1:] + "." + next_key[:end_index]
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
return key, next_key
@classmethod
def _recursive_get(cls, data: Any, key: str) -> Any:
if not data:
return None
key, next_key = cls._parse_key(key)
if next_key is not None:
return cls._recursive_get(data[key], next_key)
return data[key]
def _match_passive(self, evt: MessageEvent) -> bool:
try:
match_against = self._recursive_get(evt.content, self.match_against)
except KeyError:
match_against = None
match_against = match_against or evt.content.body
matches = [[match.string[match.start():match.end()]] + list(match.groups())
for match in self.matches.finditer(match_against)]
if not matches:
return False
if evt.unsigned.passive_command is None:
evt.unsigned.passive_command = {}
evt.unsigned.passive_command[self.name] = MatchedPassiveCommand(captured=matches)
return True
def _match_active(self, evt: MessageEvent) -> bool:
if not evt.content.body.startswith(self.starts_with):
return False
match = self.matches.match(evt.content.body)
if not match:
return False
evt.content.command = MatchedCommand(matched=self.name,
arguments=dict(zip(self.arguments, match.groups())))
return True
@dataclass
class CommandSpec(SerializableAttrs['CommandSpec']):
commands: List[Command] = []
passive_commands: List[PassiveCommand] = []
def __add__(self, other: 'CommandSpec') -> 'CommandSpec':
return CommandSpec(commands=self.commands + other.commands,
passive_commands=self.passive_commands + other.passive_commands)
def parse(self) -> List[ParsedCommand]:
return [ParsedCommand(command) for command in self.commands + self.passive_commands]

View file

@ -0,0 +1 @@
from . import event, command

View file

@ -0,0 +1,15 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

116
maubot/handlers/event.py Normal file
View file

@ -0,0 +1,116 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Union, NewType, Any, Tuple, Optional
import functools
import re
from mautrix.types import EventType, Event, EventContent, MessageEvent, MessageEventContent
from mautrix.client import EventHandler
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
def handler(var: Union[EventType, EventHandler]) -> Union[EventHandlerDecorator, EventHandler]:
def decorator(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = True
if isinstance(var, EventType):
func.__mb_event_type__ = var
else:
func.__mb_event_type__ = EventType.ALL
return func
if isinstance(var, EventType):
return decorator
else:
decorator(var)
class Field:
body: Callable[[MessageEventContent], str] = lambda content: content.body
msgtype: Callable[[MessageEventContent], str] = lambda content: content.msgtype
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
if '.' not in key:
return key, None
key, next_key = key.split('.', 1)
if len(key) > 0 and key[0] == "[":
end_index = next_key.index("]")
key = key[1:] + "." + next_key[:end_index]
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
return key, next_key
def _recursive_get(data: EventContent, key: str) -> Any:
key, next_key = _parse_key(key)
if next_key is not None:
next_data = data.get(key, None)
if next_data is None:
return None
return _recursive_get(next_data, next_key)
return data.get(key, None)
def _find_content_field(content: EventContent, field: str) -> Any:
val = _recursive_get(content, field)
if not val and hasattr(content, "unrecognized_"):
val = _recursive_get(content.unrecognized_, field)
return val
def handle_own_events(func: EventHandler) -> EventHandler:
func.__mb_handle_own_events__ = True
def filter_content(field: Union[str, Callable[[EventContent], Any]], substr: str = None,
pattern: str = None, exact: bool = False):
if substr and pattern:
raise ValueError("You can only provide one of substr or pattern.")
elif not substr and not pattern:
raise ValueError("You must provide either substr or pattern.")
if not callable(field):
field = functools.partial(_find_content_field, field=field)
if substr:
def func(evt: MessageEvent) -> bool:
val = field(evt.content)
if val is None:
return False
elif substr in val:
return True
else:
pattern = re.compile(pattern)
def func(evt: MessageEvent) -> bool:
val = field(evt.content)
if val is None:
return False
elif pattern.match(val):
return True
return filter(func)
def filter(func: Callable[[MessageEvent], bool]) -> EventHandlerDecorator:
def decorator(func: EventHandler) -> EventHandler:
if not hasattr(func, "__mb_event_filters__"):
func.__mb_event_filters__ = []
func.__mb_event_filters__.append(func)
return func
return decorator

View file

@ -15,12 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional
from asyncio import AbstractEventLoop
import os.path
import logging
import io
from sqlalchemy.orm import Session
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
from sqlalchemy.orm import Session
import sqlalchemy as sql
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.types import UserID
@ -133,8 +135,12 @@ class PluginInstance:
except (FileNotFoundError, KeyError):
self.base_cfg = None
self.config = config_class(self.load_config, lambda: self.base_cfg, self.save_config)
self.plugin = cls(self.client.client, self.loop, self.client.http_client, self.id,
self.log, self.config, self.mb_config["plugin_directories.db"])
db = None
if self.loader.meta.database:
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
db = sql.create_engine(f"sqlite:///{db_path}.db")
self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client,
instance_id=self.id, log=self.log, config=self.config, database=db)
try:
await self.plugin.start()
except Exception:

View file

@ -22,6 +22,7 @@ from packaging.version import Version, InvalidVersion
from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
deserializer)
from ..__meta__ import __version__
from ..plugin_base import Plugin
if TYPE_CHECKING:
@ -51,9 +52,12 @@ def deserialize_version(version: str) -> Version:
class PluginMeta(SerializableAttrs['PluginMeta']):
id: str
version: Version
license: str
modules: List[str]
main_class: str
maubot: Version = Version(__version__)
database: bool = False
license: str = ""
extra_files: List[str] = []
dependencies: List[str] = []
soft_dependencies: List[str] = []

View file

@ -13,19 +13,16 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Union, Callable, Awaitable, Optional, Tuple
from typing import Union, Awaitable, Optional, Tuple
from markdown.extensions import Extension
import markdown as md
import attr
from mautrix import Client as MatrixClient
from mautrix.util.formatter import parse_html
from mautrix.client import EventHandler
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
MessageType, TextMessageEventContent, Format, RelatesTo)
from .command_spec import ParsedCommand, CommandSpec
class EscapeHTML(Extension):
def extendMarkdown(self, md):
@ -71,14 +68,6 @@ class MaubotMessageEvent(MessageEvent):
class MaubotMatrixClient(MatrixClient):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.command_specs: Dict[str, CommandSpec] = {}
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
def send_markdown(self, room_id: RoomID, markdown: str, msgtype: MessageType = MessageType.TEXT,
relates_to: Optional[RelatesTo] = None, **kwargs) -> Awaitable[EventID]:
content = TextMessageEventContent(msgtype=msgtype, format=Format.HTML)
@ -87,53 +76,6 @@ class MaubotMatrixClient(MatrixClient):
content.relates_to = relates_to
return self.send_message(room_id, content, **kwargs)
def set_command_spec(self, plugin_id: str, spec: CommandSpec) -> None:
self.command_specs[plugin_id] = spec
self._reparse_command_specs()
def _reparse_command_specs(self) -> None:
self.commands = [parsed_command
for spec in self.command_specs.values()
for parsed_command in spec.parse()]
def remove_command_spec(self, plugin_id: str) -> None:
try:
del self.command_specs[plugin_id]
self._reparse_command_specs()
except KeyError:
pass
async def _command_event_handler(self, evt: MessageEvent) -> None:
if evt.sender == self.mxid or evt.content.msgtype == MessageType.NOTICE:
return
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass
async def call_handlers(self, event: Event) -> None:
if isinstance(event, MessageEvent):
event = MaubotMessageEvent(event, self)

View file

@ -14,21 +14,18 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Optional, TYPE_CHECKING
from abc import ABC
from logging import Logger
from abc import ABC, abstractmethod
from asyncio import AbstractEventLoop
from aiohttp import ClientSession
import os.path
import functools
from sqlalchemy.engine.base import Engine
import sqlalchemy as sql
from aiohttp import ClientSession
if TYPE_CHECKING:
from .client import MaubotMatrixClient
from .command_spec import CommandSpec
from mautrix.types import Event
from mautrix.util.config import BaseProxyConfig
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
from .client import MaubotMatrixClient
class Plugin(ABC):
@ -37,33 +34,40 @@ class Plugin(ABC):
log: Logger
loop: AbstractEventLoop
config: Optional['BaseProxyConfig']
database: Optional[Engine]
def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession,
plugin_instance_id: str, log: Logger, config: Optional['BaseProxyConfig'],
db_base_path: str) -> None:
instance_id: str, log: Logger, config: Optional['BaseProxyConfig'],
database: Optional[Engine]) -> None:
self.client = client
self.loop = loop
self.http = http
self.id = plugin_instance_id
self.id = instance_id
self.log = log
self.config = config
self.__db_base_path = db_base_path
self.database = database
self._handlers_at_startup = []
def request_db_engine(self) -> Optional[Engine]:
if not self.__db_base_path:
raise DatabaseNotConfigured
return sql.create_engine(f"sqlite:///{os.path.join(self.__db_base_path, self.id)}.db")
def set_command_spec(self, spec: 'CommandSpec') -> None:
self.client.set_command_spec(self.id, spec)
@abstractmethod
async def start(self) -> None:
pass
for key in dir(self):
val = getattr(self, key)
if hasattr(val, "__mb_event_handler__"):
handle_own_events = hasattr(val, "__mb_handle_own_events__")
@functools.wraps(val)
async def handler(event: Event) -> None:
if not handle_own_events and getattr(event, "sender", "") == self.client.mxid:
return
for filter in val.__mb_event_filters__:
if not filter(event):
return
await val(event)
self._handlers_at_startup.append((handler, val.__mb_event_type__))
self.client.add_event_handler(val.__mb_event_type__, handler)
@abstractmethod
async def stop(self) -> None:
pass
for func, event_type in self._handlers_at_startup:
self.client.remove_event_handler(event_type, func)
@classmethod
def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]: