maubot_reactbot/reactbot.py

189 lines
6.8 KiB
Python
Raw Normal View History

2019-06-21 11:49:50 +00:00
# reactbot - A maubot plugin that reacts to messages that match predefined rules.
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import NewType, Optional, Pattern, Match, Union, Dict, List, Tuple, Set, Type, Any
from itertools import chain
import copy
2019-06-21 11:49:50 +00:00
import re
from attr import dataclass
from jinja2 import Template as JinjaTemplate
from mautrix.types import RoomID, EventType, Event
2019-06-21 11:49:50 +00:00
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
2019-06-22 23:08:30 +00:00
from maubot import Plugin, MessageEvent
2019-06-21 11:49:50 +00:00
from maubot.handlers import event
class Config(BaseProxyConfig):
    """Configuration proxy for the plugin.

    On update, the ``rules`` and ``templates`` sections are carried over from
    the user's config (in that order).
    """

    def do_update(self, helper: ConfigUpdateHelper) -> None:
        for section_name in ("rules", "templates"):
            helper.copy(section_name)
class Key(str):
    """Marker ``str`` subclass used as the last element of a template path.

    When a path ends in a ``Key`` rather than a plain ``str``/``int``, the
    placeholder lives in the dict key itself (so the key must be renamed),
    not in the value stored under it.
    """
# One step of a path into template content: a dict key (str), a list index
# (int), or a Key marking a dict key that itself contains placeholders.
# NewType() needs a subclassable base type and a Union is not one, so this is
# a plain type alias; it is only used in annotations.
Index = Union[str, int, Key]

# Matches a ``$${name}`` placeholder; group 1 is the variable name, made of
# ASCII letters, digits, "-" and "_".
variable_regex = re.compile(r"\$\${([0-9A-Za-z-_]+)}")
@dataclass
class Template:
    """An event template whose content may contain ``$${name}`` placeholders.

    Placeholders may appear anywhere inside string values of ``content``, at
    any nesting depth within dicts and lists, and inside dict keys. Call
    ``init()`` once after construction to index where the placeholders are.
    ``execute()`` then returns a filled-in copy of ``content``.
    """

    # Event type to send when the rule using this template does not override it.
    type: EventType
    # Named Jinja templates, rendered with the triggering event as ``event``.
    variables: Dict[str, JinjaTemplate]
    # Raw event content, possibly containing ``$${name}`` placeholders.
    content: Dict[str, Any]
    # Path (from the root of ``content``) to every string or dict key holding a
    # placeholder. A trailing ``Key`` element means "rename this dict key";
    # any other trailing element means "substitute inside the value there".
    # Populated by init().
    _variable_locations: List[Tuple[Index, ...]] = None

    def init(self) -> 'Template':
        """Index the placeholder locations in ``content`` and return ``self``."""
        self._variable_locations = []
        self._map_variable_locations((), self.content)
        return self

    def _map_variable_locations(self, path: Tuple[Index, ...], data: Any) -> None:
        """Recursively record the path of every placeholder under ``data``.

        Locations are recorded post-order. A dict key that needs renaming is
        recorded *after* everything nested beneath it. As a result, ``execute()``
        fills in the nested values (whose paths still use the original key)
        before the key itself is renamed.
        """
        if isinstance(data, list):
            for i, v in enumerate(data):
                self._map_variable_locations((*path, i), v)
        elif isinstance(data, dict):
            for k, v in data.items():
                self._map_variable_locations((*path, k), v)
                # search(), not match(): a placeholder can appear anywhere in the
                # key, not only at its start. Non-str keys can't hold placeholders.
                if isinstance(k, str) and variable_regex.search(k):
                    self._variable_locations.append((*path, Key(k)))
        elif isinstance(data, str):
            if variable_regex.search(data):
                self._variable_locations.append(path)

    @classmethod
    def _recurse(cls, content: Any, path: Tuple[Index, ...]) -> Any:
        """Return the object reached by indexing ``content`` with each element of ``path``."""
        if len(path) == 0:
            return content
        return cls._recurse(content[path[0]], path[1:])

    @staticmethod
    def _replace_variables(tpl: str, variables: Dict[str, Any]) -> str:
        """Return ``tpl`` with every ``$${name}`` replaced by ``str(variables[name])``.

        Raises:
            KeyError: if a placeholder names a variable missing from ``variables``.
        """
        # re.sub computes every replacement against the original string in a
        # single pass. Match offsets therefore stay valid even when a value's
        # length differs from its placeholder's. A callable repl also keeps
        # backslashes in values from being interpreted as group references.
        return variable_regex.sub(lambda m: str(variables[m.group(1)]), tpl)

    def execute(self, evt: Event, rule_vars: Dict[str, JinjaTemplate], extra_vars: Dict[str, str]
                ) -> Dict[str, Any]:
        """Build the event content to send in response to ``evt``.

        Args:
            evt: The triggering event, exposed to every Jinja variable as ``event``.
            rule_vars: The rule's own Jinja variables; on a name clash they
                override this template's variables.
            extra_vars: Already-rendered values; these override both of the above.

        Returns:
            A deep copy of ``content`` with every placeholder substituted.

        Raises:
            KeyError: if a placeholder names an undefined variable.
        """
        variables = {
            **{name: template.render(event=evt)
               for name, template in chain(self.variables.items(), rule_vars.items())},
            **extra_vars,
        }
        # Work on a copy so the template can be executed again.
        content = copy.deepcopy(self.content)
        for path in self._variable_locations:
            # Container (dict or list) holding the placeholder.
            data: Any = self._recurse(content, path[:-1])
            key = path[-1]
            if isinstance(key, Key):
                key = str(key)
                data[self._replace_variables(key, variables)] = data.pop(key)
            else:
                data[key] = self._replace_variables(data[key], variables)
        return content

@dataclass
class Rule:
    """A reaction rule: when a message body matches, send an event built from a template."""

    # Rooms the rule applies in; an empty set means every room.
    rooms: Set[RoomID]
    # Compiled patterns, tried in order against the message body.
    matches: List[Pattern]
    # Template used to build the response content.
    template: Template
    # Event type override; when None, the template's own type is used.
    type: Optional[EventType]
    # Rule-specific Jinja variables handed to the template (override its own).
    variables: Dict[str, JinjaTemplate]

    def match(self, evt: MessageEvent) -> Optional[Match]:
        """Return the first pattern match in ``evt``'s body, or None if the rule doesn't apply."""
        if self.rooms and evt.room_id not in self.rooms:
            return None
        for pattern in self.matches:
            match = pattern.search(evt.content.body)
            if match is not None:
                return match
        return None

    async def execute(self, evt: MessageEvent, match: Match) -> None:
        """Send this rule's response to the room ``evt`` was sent in.

        The regex capture groups are exposed to the template as variables
        named ``"0"``, ``"1"``, ... (capture group 1 is ``"0"``). Groups that
        did not take part in the match are passed as empty strings instead of
        None, so optional groups don't break placeholder substitution.
        """
        extra_vars = {str(i): val for i, val in enumerate(match.groups(default=""))}
        content = self.template.execute(evt=evt, rule_vars=self.variables,
                                        extra_vars=extra_vars)
        await evt.client.send_message_event(evt.room_id, self.type or self.template.type, content)

class ConfigError(Exception):
    """Raised when a configured template or rule cannot be loaded.

    The underlying failure is attached as ``__cause__`` by the raiser.
    """

class ReactBot(Plugin):
    """maubot plugin that answers room messages matching configured regex rules."""

    # Loaded rules and templates, keyed by their names in the config.
    rules: Dict[str, Rule]
    templates: Dict[str, Template]

    @classmethod
    def get_config_class(cls) -> Type[BaseProxyConfig]:
        return Config

    async def start(self) -> None:
        """Start with no rules or templates, then load them from the config."""
        await super().start()
        self.rules = {}
        self.templates = {}
        self.on_external_config_update()

    @staticmethod
    def _parse_variables(data: Dict[str, Any]) -> Dict[str, JinjaTemplate]:
        """Compile the optional ``variables`` mapping in ``data`` into Jinja templates."""
        return {name: JinjaTemplate(var_tpl) for name, var_tpl
                in data.get("variables", {}).items()}

    def _make_template(self, name: str, tpl: Dict[str, Any]) -> Template:
        """Build and index a Template from its config entry.

        ``type`` defaults to ``m.room.message`` and ``content`` to ``{}``.

        Raises:
            ConfigError: if the entry is invalid (original error chained as cause).
        """
        try:
            return Template(type=EventType.find(tpl.get("type", "m.room.message")),
                            variables=self._parse_variables(tpl),
                            content=tpl.get("content", {})).init()
        except Exception as e:
            raise ConfigError(f"Failed to load template {name}") from e

    def _make_rule(self, name: str, rule: Dict[str, Any],
                   templates: Optional[Dict[str, Template]] = None) -> Rule:
        """Build a Rule from its config entry.

        Args:
            name: Rule name, used in error messages.
            rule: The rule's config mapping.
            templates: Templates to resolve ``rule["template"]`` against;
                defaults to ``self.templates``.

        Raises:
            ConfigError: if the entry is invalid or names an unknown template.
        """
        if templates is None:
            templates = self.templates
        try:
            return Rule(rooms=set(rule.get("rooms", [])),
                        matches=[re.compile(match) for match in rule.get("matches")],
                        type=EventType.find(rule["type"]) if "type" in rule else None,
                        template=templates[rule["template"]],
                        variables=self._parse_variables(rule))
        except Exception as e:
            raise ConfigError(f"Failed to load rule {name}") from e

    def on_external_config_update(self) -> None:
        """Reload the config and rebuild all templates and rules.

        Templates and rules are both built before either is published. A
        config error therefore keeps the previously loaded set intact instead
        of leaving new templates alongside old rules.
        """
        self.config.load_and_update()
        try:
            templates = {name: self._make_template(name, tpl)
                         for name, tpl in self.config["templates"].items()}
            rules = {name: self._make_rule(name, rule, templates)
                     for name, rule in self.config["rules"].items()}
        except ConfigError:
            self.log.exception("Failed to load config")
            return
        self.templates = templates
        self.rules = rules

    @event.on(EventType.ROOM_MESSAGE)
    async def event_handler(self, evt: MessageEvent) -> None:
        """Run the first rule matching ``evt``; the bot's own messages are ignored."""
        if evt.sender == self.client.mxid:
            return
        for name, rule in self.rules.items():
            match = rule.match(evt)
            if match is None:
                continue
            try:
                await rule.execute(evt, match)
            except Exception:
                self.log.exception(f"Failed to execute {name}")
            # Only the first matching rule fires, whether or not it succeeded.
            return