diff --git a/base-config.yaml b/base-config.yaml index a63f3b6..a523deb 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -26,6 +26,9 @@ templates: mimetype: image/png size: 233763 +default_flags: +- ignorecase + rules: twim_cookies: rooms: ["!FPUfgzXYWTKgIrwKxW:matrix.org"] diff --git a/reactbot.py b/reactbot.py index 966936f..87472bc 100644 --- a/reactbot.py +++ b/reactbot.py @@ -33,6 +33,7 @@ class Config(BaseProxyConfig): def do_update(self, helper: ConfigUpdateHelper) -> None: helper.copy("rules") helper.copy("templates") + helper.copy("default_flags") class ConfigError(Exception): @@ -53,11 +54,15 @@ class SimplePattern: _ptm = BlankMatch() matcher: Callable[[str], bool] + ignorecase: bool - def __init__(self, matcher: Callable[[str], bool]) -> None: + def __init__(self, matcher: Callable[[str], bool], ignorecase: bool) -> None: self.matcher = matcher + self.ignorecase = ignorecase - def match(self, val: str) -> BlankMatch: + def search(self, val: str) -> BlankMatch: + if self.ignorecase: + val = val.lower() if self.matcher(val): return self._ptm @@ -135,12 +140,20 @@ class Rule: type: Optional[EventType] variables: Dict[str, JinjaTemplate] + def _check_not_match(self, body: str) -> bool: + for pattern in self.not_matches: + if pattern.search(body): + return True + return False + def match(self, evt: MessageEvent) -> Optional[Match]: if len(self.rooms) > 0 and evt.room_id not in self.rooms: return None for pattern in self.matches: match = pattern.search(evt.content.body) if match: + if self._check_not_match(evt.content.body): + return None return match return None @@ -154,6 +167,7 @@ class Rule: class ReactBot(Plugin): rules: Dict[str, Rule] templates: Dict[str, Template] + default_flags: re.RegexFlag @classmethod def get_config_class(cls) -> Type[BaseProxyConfig]: @@ -178,21 +192,52 @@ class ReactBot(Plugin): except Exception as e: raise ConfigError(f"Failed to load {name}") from e - @staticmethod - def _compile(pattern: str) -> RPattern: - esc = re.escape(pattern) - if esc == pattern: - return SimplePattern(lambda val: pattern in val) - elif pattern[0] == '^' and esc == f"\\^{pattern}": - pattern = pattern[1:] - return SimplePattern(lambda val: val.startswith(pattern)) - elif pattern[-1] == '$' and esc == f"{pattern}\\$": - pattern = pattern[:-1] - return SimplePattern(lambda val: val.endswith(pattern)) - elif pattern[0] == '^' and pattern[-1] == '$' and esc == f"\\^{pattern}\\$": - pattern = pattern[1:-1] - return SimplePattern(lambda val: val == pattern) - return re.compile(pattern) + def _get_flags(self, flags: str) -> re.RegexFlag: + output = self.default_flags + for flag in flags: + flag = flag.lower() + if flag == "i" or flag == "ignorecase": + output |= re.IGNORECASE + elif flag == "s" or flag == "dotall": + output |= re.DOTALL + elif flag == "x" or flag == "verbose": + output |= re.VERBOSE + elif flag == "m" or flag == "multiline": + output |= re.MULTILINE + elif flag == "l" or flag == "locale": + output |= re.LOCALE + elif flag == "u" or flag == "unicode": + output |= re.UNICODE + elif flag == "a" or flag == "ascii": + output |= re.ASCII + return output + + def _compile(self, pattern: str) -> RPattern: + flags = self.default_flags + raw = False + if isinstance(pattern, dict): + flags = self._get_flags(pattern.get("flags", "")) + pattern = pattern["pattern"] + raw = pattern.get("raw", False) + if not flags or flags == re.IGNORECASE: + ignorecase = flags == re.IGNORECASE + s_pattern = pattern.lower() if ignorecase else pattern + esc = "" + if not raw: + esc = re.escape(pattern) + first, last = pattern[0], pattern[-1] + if first == '^' and last == '$' and (raw or esc == f"\\^{pattern[1:-1]}\\$"): + s_pattern = s_pattern[1:-1] + return SimplePattern(lambda val: val == s_pattern, ignorecase=ignorecase) + elif first == '^' and (raw or esc == f"\\^{pattern[1:]}"): + s_pattern = s_pattern[1:] + return SimplePattern(lambda val: val.startswith(s_pattern), ignorecase=ignorecase) + elif last == '$' and (raw or esc == f"{pattern[:-1]}\\$"): + s_pattern = s_pattern[:-1] + return SimplePattern(lambda val: val.endswith(s_pattern), ignorecase=ignorecase) + elif raw or esc == pattern: + return SimplePattern(lambda val: s_pattern in val, ignorecase=ignorecase) + return re.compile(pattern, flags=flags) def _compile_all(self, patterns: List[str]) -> List[RPattern]: return [self._compile(pattern) for pattern in patterns] @@ -211,6 +256,8 @@ class ReactBot(Plugin): def on_external_config_update(self) -> None: self.config.load_and_update() try: + self.default_flags = re.RegexFlag(0) + self.default_flags = self._get_flags(self.config["default_flags"]) self.templates = {name: self._make_template(name, tpl) for name, tpl in self.config["templates"].items()} self.rules = {name: self._make_rule(name, rule)