Add support for regex flags

This commit is contained in:
Tulir Asokan 2019-06-23 03:45:00 +03:00
parent 3992db4464
commit 98c3bfb252
2 changed files with 67 additions and 17 deletions

View file

@ -26,6 +26,9 @@ templates:
mimetype: image/png mimetype: image/png
size: 233763 size: 233763
default_flags:
- ignorecase
rules: rules:
twim_cookies: twim_cookies:
rooms: ["!FPUfgzXYWTKgIrwKxW:matrix.org"] rooms: ["!FPUfgzXYWTKgIrwKxW:matrix.org"]

View file

@ -33,6 +33,7 @@ class Config(BaseProxyConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None: def do_update(self, helper: ConfigUpdateHelper) -> None:
helper.copy("rules") helper.copy("rules")
helper.copy("templates") helper.copy("templates")
helper.copy("default_flags")
class ConfigError(Exception): class ConfigError(Exception):
@ -53,11 +54,15 @@ class SimplePattern:
_ptm = BlankMatch() _ptm = BlankMatch()
matcher: Callable[[str], bool] matcher: Callable[[str], bool]
ignorecase: bool
def __init__(self, matcher: Callable[[str], bool]) -> None: def __init__(self, matcher: Callable[[str], bool], ignorecase: bool) -> None:
self.matcher = matcher self.matcher = matcher
self.ignorecase = ignorecase
def match(self, val: str) -> BlankMatch: def search(self, val: str) -> BlankMatch:
if self.ignorecase:
val = val.lower()
if self.matcher(val): if self.matcher(val):
return self._ptm return self._ptm
@ -135,12 +140,20 @@ class Rule:
type: Optional[EventType] type: Optional[EventType]
variables: Dict[str, JinjaTemplate] variables: Dict[str, JinjaTemplate]
def _check_not_match(self, body: str) -> bool:
for pattern in self.not_matches:
if pattern.search(body):
return True
return False
def match(self, evt: MessageEvent) -> Optional[Match]: def match(self, evt: MessageEvent) -> Optional[Match]:
if len(self.rooms) > 0 and evt.room_id not in self.rooms: if len(self.rooms) > 0 and evt.room_id not in self.rooms:
return None return None
for pattern in self.matches: for pattern in self.matches:
match = pattern.search(evt.content.body) match = pattern.search(evt.content.body)
if match: if match:
if self._check_not_match(evt.content.body):
return None
return match return match
return None return None
@ -154,6 +167,7 @@ class Rule:
class ReactBot(Plugin): class ReactBot(Plugin):
rules: Dict[str, Rule] rules: Dict[str, Rule]
templates: Dict[str, Template] templates: Dict[str, Template]
default_flags: re.RegexFlag
@classmethod @classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]: def get_config_class(cls) -> Type[BaseProxyConfig]:
@ -178,21 +192,52 @@ class ReactBot(Plugin):
except Exception as e: except Exception as e:
raise ConfigError(f"Failed to load {name}") from e raise ConfigError(f"Failed to load {name}") from e
@staticmethod def _get_flags(self, flags: str) -> re.RegexFlag:
def _compile(pattern: str) -> RPattern: output = self.default_flags
esc = re.escape(pattern) for flag in flags:
if esc == pattern: flag = flag.lower()
return SimplePattern(lambda val: pattern in val) if flag == "i" or flag == "ignorecase":
elif pattern[0] == '^' and esc == f"\\^{pattern}": output |= re.IGNORECASE
pattern = pattern[1:] elif flag == "s" or flag == "dotall":
return SimplePattern(lambda val: val.startswith(pattern)) output |= re.DOTALL
elif pattern[-1] == '$' and esc == f"{pattern}\\$": elif flag == "x" or flag == "verbose":
pattern = pattern[:-1] output |= re.VERBOSE
return SimplePattern(lambda val: val.endswith(pattern)) elif flag == "m" or flag == "multiline":
elif pattern[0] == '^' and pattern[-1] == '$' and esc == f"\\^{pattern}\\$": output |= re.MULTILINE
pattern = pattern[1:-1] elif flag == "l" or flag == "locale":
return SimplePattern(lambda val: val == pattern) output |= re.LOCALE
return re.compile(pattern) elif flag == "u" or flag == "unicode":
output |= re.UNICODE
elif flag == "a" or flag == "ascii":
output |= re.ASCII
return output
def _compile(self, pattern: str) -> RPattern:
    """Turn one configured pattern into an object with a ``search`` method.

    Args:
        pattern: Either a plain pattern string, or a mapping with a
            ``pattern`` key (the pattern string) and optional ``flags``
            and ``raw`` keys. Without a mapping, ``self.default_flags``
            apply and ``raw`` is ``False``.

    Returns:
        A ``SimplePattern`` when the effective flags are empty or exactly
        ``re.IGNORECASE`` and the pattern is plain text (optionally with a
        leading ``^`` and/or trailing ``$``), or when ``raw`` is set under
        those same flags. Otherwise a pattern compiled with ``re.compile``.

    Raises:
        KeyError: If a mapping is given without a ``pattern`` key.
        re.error: If the pattern is not a valid regular expression.
    """
    flags = self.default_flags
    raw = False
    if isinstance(pattern, dict):
        flags = self._get_flags(pattern.get("flags", ""))
        # Read "raw" while ``pattern`` is still the mapping. Doing it after
        # the rebind below would call ``.get`` on the pattern string and
        # raise AttributeError for every mapping-style pattern.
        raw = pattern.get("raw", False)
        pattern = pattern["pattern"]

    if not flags or flags == re.IGNORECASE:
        ignorecase = flags == re.IGNORECASE
        # SimplePattern lowercases the input when ignorecase is set, so the
        # text it is compared against has to be lowercased as well.
        needle = pattern.lower() if ignorecase else pattern
        # With raw set the pattern is taken as plain text, so no escaping
        # comparison is needed.
        esc = "" if raw else re.escape(pattern)
        # startswith/endswith instead of indexing, so that an empty pattern
        # does not raise IndexError.
        starts_anchored = pattern.startswith("^")
        ends_anchored = pattern.endswith("$")

        if starts_anchored and ends_anchored and (raw or esc == f"\\^{pattern[1:-1]}\\$"):
            exact = needle[1:-1]
            return SimplePattern(lambda val: val == exact, ignorecase=ignorecase)
        elif starts_anchored and (raw or esc == f"\\^{pattern[1:]}"):
            prefix = needle[1:]
            return SimplePattern(lambda val: val.startswith(prefix), ignorecase=ignorecase)
        elif ends_anchored and (raw or esc == f"{pattern[:-1]}\\$"):
            suffix = needle[:-1]
            return SimplePattern(lambda val: val.endswith(suffix), ignorecase=ignorecase)
        elif raw or esc == pattern:
            return SimplePattern(lambda val: needle in val, ignorecase=ignorecase)

    # Anything containing regex syntax, or using other flags, goes to re.
    return re.compile(pattern, flags=flags)
def _compile_all(self, patterns: List[str]) -> List[RPattern]: def _compile_all(self, patterns: List[str]) -> List[RPattern]:
return [self._compile(pattern) for pattern in patterns] return [self._compile(pattern) for pattern in patterns]
@ -211,6 +256,8 @@ class ReactBot(Plugin):
def on_external_config_update(self) -> None: def on_external_config_update(self) -> None:
self.config.load_and_update() self.config.load_and_update()
try: try:
self.default_flags = re.RegexFlag(0)
self.default_flags = self._get_flags(self.config["default_flags"])
self.templates = {name: self._make_template(name, tpl) self.templates = {name: self._make_template(name, tpl)
for name, tpl in self.config["templates"].items()} for name, tpl in self.config["templates"].items()}
self.rules = {name: self._make_rule(name, rule) self.rules = {name: self._make_rule(name, rule)