maubot_reactbot/reactbot/config.py

142 lines
5.4 KiB
Python

# reactbot - A maubot plugin that reacts to messages that match predefined rules.
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Dict, List, Union
import re
from jinja2 import Template as JinjaStringTemplate
from jinja2.nativetypes import NativeTemplate as JinjaNativeTemplate
from mautrix.types import EventType
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from .rule import RPattern, Rule
from .simplepattern import SimplePattern
from .template import Template
# Shape of one pattern entry as written in the config: either a plain pattern
# string, or a mapping whose "pattern", "flags" and "raw" keys are read by
# Config._compile.
InputPattern = Union[str, Dict[str, str]]
class Config(BaseProxyConfig):
    """Plugin configuration that parses raw config data into rules and templates.

    The parsed attributes below are (re)built by :meth:`parse_data`.
    """

    # Parsed rules, keyed by the rule name used in the config.
    rules: Dict[str, Rule]
    # Parsed templates, keyed by the template name used in the config.
    templates: Dict[str, Template]
    # Regex flags used for every pattern that does not specify its own "flags".
    default_flags: re.RegexFlag

    def do_update(self, helper: ConfigUpdateHelper) -> None:
        """Copy every config key this plugin knows about through ``helper``."""
        for key in (
            "rules",
            "templates",
            "default_flags",
            "antispam.user.max",
            "antispam.user.delay",
            "antispam.room.max",
            "antispam.room.delay",
        ):
            helper.copy(key)

    def parse_data(self) -> None:
        """Build ``default_flags``, ``templates`` and ``rules`` from the raw config.

        Raises:
            ConfigError: if a template or rule cannot be loaded.
        """
        # Reset everything first; presumably so the attributes hold sane empty
        # values if parsing below raises part-way through.
        self.default_flags = re.RegexFlag(0)
        self.templates = {}
        self.rules = {}

        self.default_flags = self._get_flags(self["default_flags"])
        # Templates are built before rules because _make_rule looks them up by name.
        self.templates = {
            name: self._make_template(name, tpl) for name, tpl in self["templates"].items()
        }
        self.rules = {name: self._make_rule(name, rule) for name, rule in self["rules"].items()}

    def _make_rule(self, name: str, rule: Dict[str, Any]) -> Rule:
        """Build a :class:`Rule` from one entry of the ``rules`` config section.

        Args:
            name: The rule's key in the config, used only in the error message.
            rule: The raw rule mapping. ``matches`` and ``template`` are required;
                ``rooms``, ``not_rooms``, ``not_matches``, ``type`` and
                ``variables`` are optional.

        Raises:
            ConfigError: wrapping any exception raised while building the rule.
        """
        try:
            return Rule(
                rooms=set(rule.get("rooms", [])),
                not_rooms=set(rule.get("not_rooms", [])),
                matches=self._compile_all(rule["matches"]),
                not_matches=self._compile_all(rule.get("not_matches", [])),
                type=EventType.find(rule["type"]) if "type" in rule else None,
                template=self.templates[rule["template"]],
                variables=self._parse_variables(rule),
            )
        except Exception as e:
            raise ConfigError(f"Failed to load {name}") from e

    def _make_template(self, name: str, tpl: Dict[str, Any]) -> Template:
        """Build and initialise a :class:`Template` from one ``templates`` entry.

        Args:
            name: The template's key in the config, used only in the error message.
            tpl: The raw template mapping. ``type`` defaults to ``m.room.message``;
                ``variables`` and ``content`` are optional.

        Raises:
            ConfigError: wrapping any exception raised while building the template.
        """
        try:
            return Template(
                type=EventType.find(tpl.get("type", "m.room.message")),
                variables=self._parse_variables(tpl),
                content=self._parse_content(tpl.get("content", None)),
            ).init()
        except Exception as e:
            raise ConfigError(f"Failed to load {name}") from e

    def _compile_all(self, patterns: Union[InputPattern, List[InputPattern]]) -> List[RPattern]:
        """Compile one pattern or a list of patterns; always returns a list."""
        if isinstance(patterns, list):
            return [self._compile(pattern) for pattern in patterns]
        else:
            return [self._compile(patterns)]

    def _compile(self, pattern: InputPattern) -> RPattern:
        """Compile a single config pattern into a simple pattern or a regex.

        Args:
            pattern: Either a pattern string, or a mapping with a required
                ``pattern`` key and optional ``flags`` and ``raw`` keys.

        Returns:
            The result of ``SimplePattern.compile`` when that path is taken and
            returns a truthy value, otherwise a compiled ``re`` pattern.
        """
        flags = self.default_flags
        # raw stays None for a plain string. For a mapping it is the "raw" value,
        # which defaults to False, so a mapping without "raw" is always a regex.
        raw = None
        if isinstance(pattern, dict):
            flags = self._get_flags(pattern["flags"]) if "flags" in pattern else flags
            raw = pattern.get("raw", False)
            pattern = pattern["pattern"]
        # Try the simple pattern unless raw is explicitly False. With MULTILINE set
        # it is only tried when raw is explicitly True.
        if raw is not False and (not flags & re.MULTILINE or raw is True):
            return SimplePattern.compile(pattern, flags, raw) or re.compile(pattern, flags=flags)
        return re.compile(pattern, flags=flags)

    @staticmethod
    def _parse_variables(data: Dict[str, Any]) -> Dict[str, Any]:
        """Parse the ``variables`` mapping of a rule or template.

        String values starting with ``{{`` become native Jinja templates; every
        other value is kept unchanged. A missing or null ``variables`` entry
        yields an empty dict.
        """
        # `or {}` also covers an explicit null entry, matching how _parse_content
        # treats empty content.
        variables = data.get("variables") or {}
        return {
            name: (
                JinjaNativeTemplate(var_tpl)
                if isinstance(var_tpl, str) and var_tpl.startswith("{{")
                else var_tpl
            )
            for name, var_tpl in variables.items()
        }

    @staticmethod
    def _parse_content(
        content: Union[Dict[str, Any], str, None]
    ) -> Union[Dict[str, Any], JinjaStringTemplate]:
        """Parse the ``content`` field of a template.

        Returns an empty dict for falsy content, a Jinja string template for a
        string, and the value itself otherwise.
        """
        if not content:
            return {}
        elif isinstance(content, str):
            return JinjaStringTemplate(content)
        return content

    @staticmethod
    def _get_flags(flags: Union[str, List[str], None]) -> re.RegexFlag:
        """Convert flag names from the config into a combined ``re.RegexFlag``.

        Args:
            flags: A list of flag names (short letter or long name, any case), a
                string that is a single flag name, a string of short flag
                letters such as ``"im"``, or ``None``/empty for no flags.

        Returns:
            The bitwise OR of all recognised flags. Unrecognised names are ignored.
        """
        known = {
            "i": re.IGNORECASE,
            "ignorecase": re.IGNORECASE,
            "s": re.DOTALL,
            "dotall": re.DOTALL,
            "x": re.VERBOSE,
            "verbose": re.VERBOSE,
            "m": re.MULTILINE,
            "multiline": re.MULTILINE,
            "l": re.LOCALE,
            "locale": re.LOCALE,
            "u": re.UNICODE,
            "unicode": re.UNICODE,
            "a": re.ASCII,
            "ascii": re.ASCII,
        }
        output = re.RegexFlag(0)
        if not flags:
            # None or empty: no flags, instead of a TypeError from iterating None.
            return output
        if isinstance(flags, str):
            lowered = flags.lower()
            # A string that is exactly one flag name is that flag. Iterating it
            # per character would misread e.g. "ignorecase" as the letters
            # i, a and s. Any other string is a run of short flag letters.
            flags = [lowered] if lowered in known else lowered
        for flag in flags:
            output |= known.get(flag.lower(), re.RegexFlag(0))
        return output
class ConfigError(Exception):
    """Raised when a rule or template from the config fails to load."""