Add experimental support for rules matching different event fields

This commit is contained in:
Tulir Asokan 2021-04-08 20:28:01 +03:00
parent 821e670fd5
commit fb214d8f0b
5 changed files with 82 additions and 31 deletions

View file

@ -1,4 +1,4 @@
# reminder - A maubot plugin that reacts to messages that match predefined rules.
# reactbot - A maubot plugin that reacts to messages that match predefined rules.
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, List, Dict, Optional
from typing import Callable, List, Dict, Optional, Pattern, Match
import re
@ -27,15 +27,30 @@ class BlankMatch:
return {}
class RegexPattern:
pattern: Pattern
field: Optional[List[str]]
def __init__(self, pattern: Pattern, field: Optional[List[str]] = None) -> None:
self.pattern = pattern
self.field = field
def search(self, val: str) -> Match:
return self.pattern.search(val)
class SimplePattern:
_ptm = BlankMatch()
matcher: Callable[[str], bool]
field: Optional[List[str]]
ignorecase: bool
def __init__(self, matcher: Callable[[str], bool], ignorecase: bool) -> None:
def __init__(self, matcher: Callable[[str], bool], ignorecase: bool,
field: Optional[List[str]] = None) -> None:
self.matcher = matcher
self.ignorecase = ignorecase
self.field = field
def search(self, val: str) -> BlankMatch:
if self.ignorecase:
@ -44,8 +59,8 @@ class SimplePattern:
return self._ptm
@staticmethod
def compile(pattern: str, flags: re.RegexFlag = re.RegexFlag(0), force_raw: bool = False
) -> Optional['SimplePattern']:
def compile(pattern: str, flags: re.RegexFlag = re.RegexFlag(0), force_raw: bool = False,
field: Optional[List[str]] = None) -> Optional['SimplePattern']:
ignorecase = flags == re.IGNORECASE
s_pattern = pattern.lower() if ignorecase else pattern
esc = ""
@ -54,13 +69,15 @@ class SimplePattern:
first, last = pattern[0], pattern[-1]
if first == '^' and last == '$' and (force_raw or esc == f"\\^{pattern[1:-1]}\\$"):
s_pattern = s_pattern[1:-1]
return SimplePattern(lambda val: val == s_pattern, ignorecase=ignorecase)
return SimplePattern(lambda val: val == s_pattern, ignorecase=ignorecase, field=field)
elif first == '^' and (force_raw or esc == f"\\^{pattern[1:]}"):
s_pattern = s_pattern[1:]
return SimplePattern(lambda val: val.startswith(s_pattern), ignorecase=ignorecase)
return SimplePattern(lambda val: val.startswith(s_pattern), ignorecase=ignorecase,
field=field)
elif last == '$' and (force_raw or esc == f"{pattern[:-1]}\\$"):
s_pattern = s_pattern[:-1]
return SimplePattern(lambda val: val.endswith(s_pattern), ignorecase=ignorecase)
return SimplePattern(lambda val: val.endswith(s_pattern), ignorecase=ignorecase,
field=field)
elif force_raw or esc == pattern:
return SimplePattern(lambda val: s_pattern in val, ignorecase=ignorecase)
return SimplePattern(lambda val: s_pattern in val, ignorecase=ignorecase, field=field)
return None